From 7a3ab1620a6b853f76fcdacdc35df308a255036e Mon Sep 17 00:00:00 2001
From: Anchen
Date: Fri, 26 Jul 2024 04:01:17 +1000
Subject: [PATCH 001/188] support loading a model via a custom get_model_classes (#899)
* feature(mlx_lm): support loading a model via custom get classes
* rename the param
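A minimal sketch mirroring the test added below (`CustomModel` and `CustomConfig` stand in for a user's own classes; the model id is the one used in the test):
```
import mlx.nn as nn
from mlx_lm.utils import get_model_path, load_model

class CustomConfig:
    @classmethod
    def from_dict(cls, config):
        instance = cls()
        for k, v in config.items():
            setattr(instance, k, v)
        return instance

class CustomModel(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.config = args

    def load_weights(self, weights):
        # Accept the checkpoint weights instead of mapping them onto parameters.
        self.loaded_weights = weights

def custom_get_classes(config):
    # Return (model class, model args class) for the given config dict.
    return CustomModel, CustomConfig

model_path = get_model_path("mlx-community/Qwen1.5-0.5B-Chat-4bit")
model = load_model(model_path, get_model_classes=custom_get_classes)
```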
---
llms/mlx_lm/utils.py | 10 ++++--
llms/tests/test_utils_load_model.py | 50 +++++++++++++++++++++++++++++
2 files changed, 57 insertions(+), 3 deletions(-)
create mode 100644 llms/tests/test_utils_load_model.py
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 229ee238..cffa2a89 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -9,7 +9,7 @@ import shutil
import time
from pathlib import Path
from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union
import mlx.core as mx
import mlx.nn as nn
@@ -355,6 +355,7 @@ def load_model(
model_path: Path,
lazy: bool = False,
model_config: dict = {},
+ get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> nn.Module:
"""
Load and initialize the model from a given path.
@@ -364,8 +365,11 @@ def load_model(
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
- model_config(dict, optional): Configuration parameters for the model.
+ model_config (dict, optional): Configuration parameters for the model.
Defaults to an empty dictionary.
+ get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
+ A function that returns the model class and model args class given a config.
+ Defaults to the _get_classes function.
Returns:
nn.Module: The loaded and initialized model.
@@ -392,7 +396,7 @@ def load_model(
for wf in weight_files:
weights.update(mx.load(wf))
- model_class, model_args_class = _get_classes(config=config)
+ model_class, model_args_class = get_model_classes(config=config)
model_args = model_args_class.from_dict(config)
model = model_class(model_args)
diff --git a/llms/tests/test_utils_load_model.py b/llms/tests/test_utils_load_model.py
new file mode 100644
index 00000000..73ee1352
--- /dev/null
+++ b/llms/tests/test_utils_load_model.py
@@ -0,0 +1,50 @@
+import unittest
+from pathlib import Path
+
+import mlx.nn as nn
+from mlx_lm.models.qwen2 import Model as Qwen2Model
+from mlx_lm.utils import get_model_path, load_model
+
+HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+
+
+class TestLoadModelCustomGetClasses(unittest.TestCase):
+
+ def test_load_model_with_custom_get_classes(self):
+ class CustomQwenModel(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+ self.config = args
+ self.custom_attribute = "This is a custom model"
+
+ def load_weights(self, weights):
+ self.qwenWeights = weights
+
+ class CustomQwenConfig:
+ @classmethod
+ def from_dict(cls, config):
+ instance = cls()
+ for k, v in config.items():
+ setattr(instance, k, v)
+ return instance
+
+ def custom_get_classes(config):
+ return CustomQwenModel, CustomQwenConfig
+
+ model_path = get_model_path(HF_MODEL_PATH)
+ model = load_model(model_path, get_model_classes=custom_get_classes)
+
+ self.assertIsInstance(model, CustomQwenModel)
+ self.assertTrue(hasattr(model, "custom_attribute"))
+ self.assertEqual(model.custom_attribute, "This is a custom model")
+ self.assertTrue(hasattr(model, "qwenWeights"))
+
+ def test_load_model_with_default_get_classes(self):
+ model_path = get_model_path(HF_MODEL_PATH)
+ model = load_model(model_path)
+
+ self.assertIsInstance(model, Qwen2Model)
+
+
+if __name__ == "__main__":
+ unittest.main()
From 46da74fea2404eb987e8070d84f5745a124dbbc8 Mon Sep 17 00:00:00 2001
From: otriscon <165947759+otriscon@users.noreply.github.com>
Date: Thu, 25 Jul 2024 19:45:22 -0400
Subject: [PATCH 002/188] Unify attention mask in LLMs (#911)
* Unify attention mask creation in LLMs.
Currently, each model implementation in `mlx-examples/llms/models` has ad-hoc
code to create a mask for the attention mechanism. This usually takes the form:
```
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
```
This correctly creates a mask only if the input consists of more than one token.
But this code assumes the multi-token input occurs only at the start of inference.
If, for example, we are evaluating multiple tokens because of speculative
decoding or prompt cache reuse, this mask will not have the correct shape and
will cause an exception in the attention computation.
Some of the models correctly implement the mask creation with code like this:
```
mask = None
if h.shape[1] > 1:
mask = create_additive_causal_mask(
h.shape[1], cache[0].offset if cache is not None else 0
)
mask = mask.astype(h.dtype)
```
This commit unifies the attention mask creation for all models with a new
function `create_attention_mask`, reducing code duplication and helping all
models support inference performance enhancements like those mentioned above.
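With the new helper, a model's forward pass only needs a single call (a sketch in the style of the snippets above; `h` is the embedded input and `cache` the optional per-layer KV cache list):
```
from .base import create_attention_mask

# Returns None for single-token decoding steps; otherwise builds an additive
# causal mask, offset by the cache position and cast to h's dtype.
mask = create_attention_mask(h, cache)
```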
* Allow batches in LLM key-value cache
The current implementation of the LLM key-value cache assumes that
the input batch is of size 1. Input batching (evaluating multiple
alternative inputs at the same time) can be a valuable tool for
speculative sampling and other techniques.
This change removes the hard-coded batch size from the code that
resizes the key-value cache.
* Simplify causal mask creation
Use the same codepath regardless of whether there's an offset or
not. Addresses [this comment](https://github.com/ml-explore/mlx-examples/pull/911#discussion_r1691459717).
* Use old-style type annotation to avoid linter error
---
llms/mlx_lm/models/base.py | 35 +++++++++++++++++++++++--------
llms/mlx_lm/models/cohere.py | 7 ++-----
llms/mlx_lm/models/dbrx.py | 8 ++-----
llms/mlx_lm/models/deepseek_v2.py | 8 ++-----
llms/mlx_lm/models/gemma.py | 7 ++-----
llms/mlx_lm/models/gemma2.py | 7 ++-----
llms/mlx_lm/models/gpt2.py | 7 ++-----
llms/mlx_lm/models/gpt_bigcode.py | 7 ++-----
llms/mlx_lm/models/gpt_neox.py | 9 ++------
llms/mlx_lm/models/internlm2.py | 7 ++-----
llms/mlx_lm/models/llama.py | 9 ++------
llms/mlx_lm/models/minicpm.py | 7 ++-----
llms/mlx_lm/models/mixtral.py | 8 ++-----
llms/mlx_lm/models/olmo.py | 7 ++-----
llms/mlx_lm/models/openelm.py | 7 ++-----
llms/mlx_lm/models/phi.py | 10 ++++-----
llms/mlx_lm/models/phi3.py | 7 ++-----
llms/mlx_lm/models/phi3small.py | 7 ++-----
llms/mlx_lm/models/phixtral.py | 6 ++----
llms/mlx_lm/models/plamo.py | 7 ++-----
llms/mlx_lm/models/qwen.py | 8 ++-----
llms/mlx_lm/models/qwen2.py | 7 ++-----
llms/mlx_lm/models/qwen2_moe.py | 7 ++-----
llms/mlx_lm/models/stablelm.py | 8 ++-----
llms/mlx_lm/models/starcoder2.py | 7 ++-----
25 files changed, 76 insertions(+), 138 deletions(-)
diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index 8c3ecc78..3fe276d2 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -1,14 +1,9 @@
import inspect
from dataclasses import dataclass
+from typing import List, Optional
import mlx.core as mx
-
-
-def create_additive_causal_mask(N: int, offset: int = 0):
- rinds = mx.arange(offset + N)
- linds = mx.arange(offset, offset + N) if offset else rinds
- mask = linds[:, None] < rinds[None]
- return mask * -1e9
+import mlx.nn as nn
class KVCache:
@@ -29,9 +24,10 @@ class KVCache:
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
+ B = keys.shape[0]
n_steps = (self.step + keys.shape[2] - 1) // self.step
- k_shape = (1, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
- v_shape = (1, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
+ k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
+ v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
@@ -60,3 +56,24 @@ class BaseModelArgs:
if k in inspect.signature(cls).parameters
}
)
+
+
+def create_additive_causal_mask(N: int, offset: int = 0):
+ rinds = mx.arange(offset + N)
+ linds = mx.arange(offset, offset + N) if offset else rinds
+ mask = linds[:, None] < rinds[None]
+ return mask * -1e9
+
+
+def create_attention_mask(h: mx.array, cache: Optional[List[KVCache]] = None):
+ T = h.shape[1]
+ if T > 1:
+ # Input consists of multiple tokens, create a causal mask so that prior
+ # tokens do not give attention to later tokens. If a cache is in place
+ # (e.g. because of prompt cache reuse), offset the mask accordingly.
+ offset = cache[0].offset if cache is not None and cache[0] is not None else 0
+ mask = create_additive_causal_mask(T, offset)
+ mask = mask.astype(h.dtype)
+ else:
+ mask = None
+ return mask
diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py
index 621a85a2..7dc2b9bf 100644
--- a/llms/mlx_lm/models/cohere.py
+++ b/llms/mlx_lm/models/cohere.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -157,10 +157,7 @@ class CohereModel(nn.Module):
):
h = self.embed_tokens(inputs)
- mask = None
- if h.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py
index dc310ca4..7a2a7a7d 100644
--- a/llms/mlx_lm/models/dbrx.py
+++ b/llms/mlx_lm/models/dbrx.py
@@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -199,11 +199,7 @@ class DBRX(nn.Module):
):
h = self.wte(inputs)
- mask = None
- T = h.shape[1]
- if T > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.blocks)
diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py
index 308b94ba..bd743e53 100644
--- a/llms/mlx_lm/models/deepseek_v2.py
+++ b/llms/mlx_lm/models/deepseek_v2.py
@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
from .switch_layers import SwitchGLU
@@ -408,11 +408,7 @@ class DeepseekV2Model(nn.Module):
cache: Optional[KVCache] = None,
) -> mx.array:
h = self.embed_tokens(x)
- mask = None
- T = h.shape[1]
- if T > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py
index e48f1909..323ebaa6 100644
--- a/llms/mlx_lm/models/gemma.py
+++ b/llms/mlx_lm/models/gemma.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -141,10 +141,7 @@ class GemmaModel(nn.Module):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
- mask = None
- if h.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py
index 1ab403da..d4bd8a5d 100644
--- a/llms/mlx_lm/models/gemma2.py
+++ b/llms/mlx_lm/models/gemma2.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -165,10 +165,7 @@ class GemmaModel(nn.Module):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
- mask = None
- if h.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py
index ece7b6ec..81f71cac 100644
--- a/llms/mlx_lm/models/gpt2.py
+++ b/llms/mlx_lm/models/gpt2.py
@@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -136,10 +136,7 @@ class GPT2Model(nn.Module):
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
- mask = create_additive_causal_mask(
- hidden_states.shape[1], cache[0].offset if cache is not None else 0
- )
- mask = mask.astype(hidden_states.dtype)
+ mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py
index 20af3d0b..a5336203 100644
--- a/llms/mlx_lm/models/gpt_bigcode.py
+++ b/llms/mlx_lm/models/gpt_bigcode.py
@@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -147,10 +147,7 @@ class GPTBigCodeModel(nn.Module):
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
- mask = create_additive_causal_mask(
- hidden_states.shape[1], cache[0].offset if cache is not None else 0
- )
- mask = mask.astype(hidden_states.dtype)
+ mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py
index 9549f322..1d2f74b7 100644
--- a/llms/mlx_lm/models/gpt_neox.py
+++ b/llms/mlx_lm/models/gpt_neox.py
@@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, create_attention_mask
# Based on the transformers implementation at:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -150,12 +150,7 @@ class GPTNeoXModel(nn.Module):
hidden_states = self.embed_in(inputs)
- mask = None
- if hidden_states.shape[1] > 1:
- mask = create_additive_causal_mask(
- hidden_states.shape[1], cache[0].offset if cache is not None else 0
- )
- mask = mask.astype(hidden_states.dtype)
+ mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py
index fc1cfc03..2ee2af2d 100644
--- a/llms/mlx_lm/models/internlm2.py
+++ b/llms/mlx_lm/models/internlm2.py
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -195,10 +195,7 @@ class InternLM2Model(nn.Module):
):
h = self.tok_embeddings(inputs)
- mask = None
- if h.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index a697e2b7..2f323245 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache, create_additive_causal_mask
+from .base import BaseModelArgs, KVCache, create_attention_mask
@dataclass
@@ -271,12 +271,7 @@ class LlamaModel(nn.Module):
):
h = self.embed_tokens(inputs)
- mask = None
- if h.shape[1] > 1:
- mask = create_additive_causal_mask(
- h.shape[1], cache[0].offset if cache is not None else 0
- )
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py
index dbfe4186..a3d01cbb 100644
--- a/llms/mlx_lm/models/minicpm.py
+++ b/llms/mlx_lm/models/minicpm.py
@@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -160,10 +160,7 @@ class MiniCPMModel(nn.Module):
):
h = self.embed_tokens(inputs) * self.args.scale_emb
- mask = None
- if h.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index 7d1b10ac..c7d8c5c5 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@@ -164,11 +164,7 @@ class MixtralModel(nn.Module):
):
h = self.embed_tokens(inputs)
- mask = None
- T = h.shape[1]
- if T > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index 120ea9b9..8a28ad74 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -5,7 +5,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
try:
import hf_olmo
@@ -126,10 +126,7 @@ class Transformer(nn.Module):
):
h = self.wte(inputs)
- mask = None
- if h.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.blocks)
diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py
index 3fbdc58c..3f0d2605 100644
--- a/llms/mlx_lm/models/openelm.py
+++ b/llms/mlx_lm/models/openelm.py
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -180,10 +180,7 @@ class OpenELMModel(nn.Module):
):
h = self.token_embeddings(inputs)
- mask = None
- if h.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py
index 8feaa23a..520ac1ad 100644
--- a/llms/mlx_lm/models/phi.py
+++ b/llms/mlx_lm/models/phi.py
@@ -5,7 +5,7 @@ from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -138,14 +138,12 @@ class PhiModel(nn.Module):
def __call__(self, x, cache):
x = self.embed_tokens(x)
+
+ mask = create_attention_mask(x, cache)
+
if cache is None:
cache = [None] * len(self.layers)
- mask = None
- if x.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
- mask = mask.astype(x.dtype)
-
for layer, c in zip(self.layers, cache):
x = layer(x, mask, c)
return self.final_layernorm(x)
diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index dd2d6d82..2536aacb 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
from .su_rope import SuScaledRotaryEmbedding
@@ -172,10 +172,7 @@ class Phi3Model(nn.Module):
):
h = self.embed_tokens(inputs)
- mask = None
- if h.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py
index e0f2d856..de075652 100644
--- a/llms/mlx_lm/models/phi3small.py
+++ b/llms/mlx_lm/models/phi3small.py
@@ -6,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
@dataclass
@@ -263,10 +263,7 @@ class Phi3Model(nn.Module):
if self.mup_embedding_multiplier:
h = self.mup_embedding_multiplier * h
- mask = None
- if h.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py
index 40a3bc4b..f0aef0c9 100644
--- a/llms/mlx_lm/models/phixtral.py
+++ b/llms/mlx_lm/models/phixtral.py
@@ -6,6 +6,7 @@ from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
+from .base import create_attention_mask
from .switch_layers import SwitchMLP
@@ -167,10 +168,7 @@ class Model(nn.Module):
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
- mask = None
- if x.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
- mask = mask.astype(x.dtype)
+ mask = create_attention_mask(x, cache)
y = self.transformer(x, mask, cache)
return self.lm_head(y)
diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py
index 2d0ddaed..47a9ea4f 100644
--- a/llms/mlx_lm/models/plamo.py
+++ b/llms/mlx_lm/models/plamo.py
@@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -171,10 +171,7 @@ class PlamoModel(nn.Module):
) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
h = self.embed_tokens(inputs)
- mask = None
- if h.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
- mask = mask.astype(self.embed_tokens.weight.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None for _ in range(len(self.layers.layers))]
diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py
index 44b6dfd3..67816599 100644
--- a/llms/mlx_lm/models/qwen.py
+++ b/llms/mlx_lm/models/qwen.py
@@ -4,7 +4,7 @@ from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -122,11 +122,7 @@ class QwenModel(nn.Module):
def __call__(self, inputs, mask=None, cache=None):
x = self.wte(inputs)
- mask = None
- T = x.shape[1]
- if T > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
- mask = mask.astype(x.dtype)
+ mask = create_attention_mask(x, cache)
if cache is None:
cache = [None] * len(self.h)
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index fab09003..cb8268aa 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
@dataclass
@@ -151,10 +151,7 @@ class Qwen2Model(nn.Module):
):
h = self.embed_tokens(inputs)
- mask = None
- if h.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py
index 57f154a0..121ab813 100644
--- a/llms/mlx_lm/models/qwen2_moe.py
+++ b/llms/mlx_lm/models/qwen2_moe.py
@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
from .switch_layers import SwitchGLU
@@ -189,10 +189,7 @@ class Qwen2MoeModel(nn.Module):
):
h = self.embed_tokens(inputs)
- mask = None
- if h.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py
index 30e3a332..9b4d043c 100644
--- a/llms/mlx_lm/models/stablelm.py
+++ b/llms/mlx_lm/models/stablelm.py
@@ -5,7 +5,7 @@ from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -198,11 +198,7 @@ class Model(nn.Module):
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
- mask = None
- if x.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
- mask = mask.astype(x.dtype)
-
+ mask = create_attention_mask(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)
diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index 7b058d8f..a6eb5377 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache
+from .base import BaseModelArgs, KVCache, create_attention_mask
@dataclass
@@ -127,10 +127,7 @@ class Starcoder2Model(nn.Module):
):
h = self.embed_tokens(inputs)
- mask = None
- if h.shape[1] > 1:
- mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
- mask = mask.astype(h.dtype)
+ mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
From 85dc76f6e0f2cf3ee3d84c211868a6856e163f3f Mon Sep 17 00:00:00 2001
From: madroid
Date: Fri, 26 Jul 2024 23:58:52 +0800
Subject: [PATCH 003/188] Server: support stream_options (#913)
* Server: support stream_options
see https://x.com/OpenAIDevs/status/1787573348496773423
* Server: support stream_options
* Server: check None type
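For illustration, a client request that opts into the usage chunk could look like the following sketch (uses the `requests` package; endpoint and model name follow the existing SERVER.md examples):
```
import requests

response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "default_model",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
        "stream_options": {"include_usage": True},
    },
    stream=True,
)
# With include_usage set, a final chunk carrying prompt/completion/total token
# counts is sent before "data: [DONE]".
for line in response.iter_lines():
    if line:
        print(line.decode())
```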
---
llms/mlx_lm/server.py | 31 ++++++++++++++++++++++++++-----
1 file changed, 26 insertions(+), 5 deletions(-)
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 23d327e5..c13878f3 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -6,16 +6,12 @@ import logging
import time
import uuid
import warnings
-from functools import lru_cache
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
-from typing import Dict, List, Literal, NamedTuple, Optional, Tuple, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Union
import mlx.core as mx
-import mlx.nn as nn
-from transformers import PreTrainedTokenizer
-from .tokenizer_utils import TokenizerWrapper
from .utils import generate_step, load
@@ -195,6 +191,7 @@ class APIHandler(BaseHTTPRequestHandler):
# Extract request parameters from the body
self.stream = self.body.get("stream", False)
+ self.stream_options = self.body.get("stream_options", None)
self.requested_model = self.body.get("model", "default_model")
self.max_tokens = self.body.get("max_tokens", 100)
self.temperature = self.body.get("temperature", 1.0)
@@ -525,9 +522,33 @@ class APIHandler(BaseHTTPRequestHandler):
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
+ if self.stream_options is not None and self.stream_options["include_usage"]:
+ response = self.completion_usage_response(len(prompt), len(tokens))
+ self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+
self.wfile.write("data: [DONE]\n\n".encode())
self.wfile.flush()
+ def completion_usage_response(
+ self,
+ prompt_token_count: Optional[int] = None,
+ completion_token_count: Optional[int] = None,
+ ):
+ response = {
+ "id": self.request_id,
+ "system_fingerprint": f"fp_{uuid.uuid4()}",
+ "object": "chat.completion",
+ "model": self.requested_model,
+ "created": self.created,
+ "choices": [],
+ "usage": {
+ "prompt_tokens": prompt_token_count,
+ "completion_tokens": completion_token_count,
+ "total_tokens": prompt_token_count + completion_token_count,
+ },
+ }
+ return response
+
def handle_chat_completions(self) -> mx.array:
"""
Handle a chat completion request.
From 8fa12b0058a2647dff5d776f84deef43e1cb7720 Mon Sep 17 00:00:00 2001
From: Khush Gupta <78624519+khushgx@users.noreply.github.com>
Date: Thu, 1 Aug 2024 16:18:18 -0700
Subject: [PATCH 004/188] Adapters loading (#902)
* Added functionality to load adapters through POST requests so you do not need to restart the server
* ran pre-commit
* nits
* fix test
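A sketch of a request that swaps in adapters without restarting the server (the adapter path is hypothetical; per SERVER.md it must be relative to the directory the server was started in):
```
import requests

response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "default_model",
        "adapters": "adapters/my_lora",  # hypothetical path to low-rank adapters
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(response.json())
```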
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/SERVER.md | 7 +++++++
llms/mlx_lm/server.py | 22 ++++++++++++++++------
llms/tests/test_server.py | 2 +-
3 files changed, 24 insertions(+), 7 deletions(-)
diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md
index 48364bee..9c42d410 100644
--- a/llms/mlx_lm/SERVER.md
+++ b/llms/mlx_lm/SERVER.md
@@ -78,3 +78,10 @@ curl localhost:8080/v1/chat/completions \
- `logprobs`: (Optional) An integer specifying the number of top tokens and
corresponding log probabilities to return for each output in the generated
sequence. If set, this can be any value between 1 and 10, inclusive.
+
+- `model`: (Optional) A string path to a local model or Hugging Face repo id.
+ If the path is local it must be relative to the directory the server was
+ started in.
+
+- `adapters`: (Optional) A string path to low-rank adapters. The path must be
+ relative to the directory the server was started in.
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index c13878f3..7456399c 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -97,8 +97,9 @@ class ModelProvider:
"Local models must be relative to the current working dir."
)
- def load(self, model_path):
- if self.model_key == model_path:
+ # Added in adapter_path to load dynamically
+ def load(self, model_path, adapter_path=None):
+ if self.model_key == (model_path, adapter_path):
return self.model, self.tokenizer
# Remove the old model if it exists.
@@ -116,18 +117,22 @@ class ModelProvider:
if model_path == "default_model" and self.cli_args.model is not None:
model, tokenizer = load(
self.cli_args.model,
- adapter_path=self.cli_args.adapter_path,
+ adapter_path=(
+ adapter_path if adapter_path else self.cli_args.adapter_path
+ ), # if the user doesn't change the model but adds an adapter path
tokenizer_config=tokenizer_config,
)
else:
self._validate_model_path(model_path)
- model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
+ model, tokenizer = load(
+ model_path, adapter_path=adapter_path, tokenizer_config=tokenizer_config
+ )
if self.cli_args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
- self.model_key = model_path
+ self.model_key = (model_path, adapter_path)
self.model = model
self.tokenizer = tokenizer
@@ -193,6 +198,7 @@ class APIHandler(BaseHTTPRequestHandler):
self.stream = self.body.get("stream", False)
self.stream_options = self.body.get("stream_options", None)
self.requested_model = self.body.get("model", "default_model")
+ self.adapter = self.body.get("adapters", None)
self.max_tokens = self.body.get("max_tokens", 100)
self.temperature = self.body.get("temperature", 1.0)
self.top_p = self.body.get("top_p", 1.0)
@@ -204,7 +210,9 @@ class APIHandler(BaseHTTPRequestHandler):
# Load the model if needed
try:
- self.model, self.tokenizer = self.model_provider.load(self.requested_model)
+ self.model, self.tokenizer = self.model_provider.load(
+ self.requested_model, self.adapter
+ )
except:
self._set_completion_headers(404)
self.end_headers()
@@ -278,6 +286,8 @@ class APIHandler(BaseHTTPRequestHandler):
if not isinstance(self.requested_model, str):
raise ValueError("model must be a string")
+ if self.adapter is not None and not isinstance(self.adapter, str):
+ raise ValueError("adapter must be a string")
def generate_response(
self,
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
index 4d71a5a3..b8047eaa 100644
--- a/llms/tests/test_server.py
+++ b/llms/tests/test_server.py
@@ -12,7 +12,7 @@ class DummyModelProvider:
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
self.model, self.tokenizer = load(HF_MODEL_PATH)
- def load(self, model):
+ def load(self, model, adapter=None):
assert model in ["default_model", "chat_model"]
return self.model, self.tokenizer
From df744c98e67f94fba8a893cb5aba4fdb735f5f4a Mon Sep 17 00:00:00 2001
From: tidely <43219534+tidely@users.noreply.github.com>
Date: Wed, 7 Aug 2024 01:24:15 +0300
Subject: [PATCH 005/188] Predict stop sequence matches during streaming (#541)
* Predict stop sequence matches during streaming
Check for overlap between the stop sequences and the generated tokens to catch partial matches that could complete as more tokens are generated. Keep generating tokens until we can confirm whether the stop sequence is hit or not (see the sketch after this list).
* fix typo
* Change sequence_overlap logic
* range isn't inclusive, add 1 to max_overlap
* Add test_server.py
Added a test for the sequence_overlap method
* nits
* eos sequence
* finalize
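For intuition, the overlap check added here behaves like this standalone sketch of `sequence_overlap` (the token ids are illustrative):
```
def sequence_overlap(s1, s2):
    # True if some suffix of s1 equals a prefix of s2.
    max_overlap = min(len(s1), len(s2))
    return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1))

print(sequence_overlap([5, 1, 2], [2, 3]))  # True: we may be entering the stop sequence [2, 3]
print(sequence_overlap([5, 1, 2], [7, 8]))  # False: safe to flush the detokenized text
```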
---------
Co-authored-by: Y4hL <43219534+Y4hL@users.noreply.github.com>
Co-authored-by: Awni Hannun
---
llms/mlx_lm/server.py | 56 ++++++++++++++++++++++++---------------
llms/tests/test_server.py | 13 +++++++++
2 files changed, 47 insertions(+), 22 deletions(-)
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 7456399c..79ac1836 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -8,7 +8,7 @@ import uuid
import warnings
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
-from typing import Dict, List, Literal, NamedTuple, Optional, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
import mlx.core as mx
@@ -54,6 +54,21 @@ def stopping_criteria(
return StopCondition(stop_met=False, trim_length=0)
+def sequence_overlap(s1: Sequence, s2: Sequence) -> bool:
+ """
+ Checks if a suffix of s1 has overlap with a prefix of s2
+
+ Args:
+ s1 (Sequence): The first sequence
+ s2 (Sequence): The second sequence
+
+ Returns:
+ bool: If the two sequences have overlap
+ """
+ max_overlap = min(len(s1), len(s2))
+ return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1))
+
+
def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
default_role_mapping = {
"system_prompt": (
@@ -462,12 +477,13 @@ class APIHandler(BaseHTTPRequestHandler):
stop_id_sequences: List[List[int]],
):
"""
- Generate response to prompt and foward it to the client using a Server Sent Events (SSE) stream.
+ Generate response to prompt and forward it to the client using a Server
+ Sent Events (SSE) stream.
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
- stop_id_sequences (List[List[int]]):
- A list of stop words passed to the stopping_criteria function
+ stop_id_sequences (List[List[int]]): A list of stop words passed to
+ the stopping_criteria function
"""
# No additional headers are needed, call end_headers
self.end_headers()
@@ -476,12 +492,9 @@ class APIHandler(BaseHTTPRequestHandler):
detokenizer.reset()
tokens = []
- max_stop_id_sequence_len = len(max(stop_id_sequences, default=[]))
- # Buffer to store the last `max_stop_id_sequence_len` tokens
- # to check for stop conditions before writing to the stream.
- stop_sequence_buffer = []
stop_sequence_suffix = None
logging.debug(f"Starting stream:")
+
for (token, _), _ in zip(
generate_step(
prompt=prompt,
@@ -496,11 +509,6 @@ class APIHandler(BaseHTTPRequestHandler):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token)
- stop_sequence_buffer.append(token)
-
- # Continue generating tokens until buffer is as large as the longest stop_id_sequence
- if len(stop_sequence_buffer) < max_stop_id_sequence_len:
- continue
stop_condition = stopping_criteria(
tokens,
@@ -514,21 +522,25 @@ class APIHandler(BaseHTTPRequestHandler):
)
break
+ # If the end of tokens overlaps with a stop sequence, generate new
+ # tokens until we know if the stop sequence is hit or not
+ if any(
+ (sequence_overlap(tokens, sequence) for sequence in stop_id_sequences)
+ ):
+ continue
+
new_text = detokenizer.last_segment
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
- stop_sequence_buffer = []
# check is there any remaining text to send
- if stop_sequence_buffer:
- next_chunk = (
- detokenizer.last_segment
- if stop_sequence_suffix is None
- else detokenizer.last_segment[: -len(stop_sequence_suffix)]
- )
- response = self.generate_response(next_chunk, "length")
-
+ detokenizer.finalize()
+ last_segment = detokenizer.last_segment
+ if last_segment:
+ if stop_sequence_suffix is not None:
+ last_segment = last_segment[: -len(stop_sequence_suffix)]
+ response = self.generate_response(last_segment, "length")
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
index b8047eaa..baea664a 100644
--- a/llms/tests/test_server.py
+++ b/llms/tests/test_server.py
@@ -1,3 +1,4 @@
+# Copyright © 2024 Apple Inc.
import http
import threading
import unittest
@@ -76,6 +77,18 @@ class TestServer(unittest.TestCase):
self.assertIn("id", response_body)
self.assertIn("choices", response_body)
+ def test_sequence_overlap(self):
+ from mlx_lm.server import sequence_overlap
+
+ self.assertTrue(sequence_overlap([1], [1]))
+ self.assertTrue(sequence_overlap([1, 2], [1, 2]))
+ self.assertTrue(sequence_overlap([1, 3], [3, 4]))
+ self.assertTrue(sequence_overlap([1, 2, 3], [2, 3]))
+
+ self.assertFalse(sequence_overlap([1], [2]))
+ self.assertFalse(sequence_overlap([1, 2], [3, 4]))
+ self.assertFalse(sequence_overlap([1, 2, 3], [4, 1, 2, 3]))
+
if __name__ == "__main__":
unittest.main()
From 33905447f92dc7ce38680f7f6bd8d3335dd204e5 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 9 Aug 2024 11:11:58 -0700
Subject: [PATCH 006/188] Whisper updates to allow HF models (#923)
* simplify conversion and update convert for HF models
* use npz for compat
* fixes
* fixes
* fix gguf
* allow user supplied path
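A rough sketch of the updated conversion entry point (the model name and dtype are examples; per this patch, `convert` also accepts a Hugging Face repo id or a local checkpoint path):
```
import mlx.core as mx
from convert import convert

# Convert the official "tiny" checkpoint to MLX weights in float16.
model = convert("tiny", dtype=mx.float16)
```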
---
README.md | 1 +
llms/mlx_lm/gguf.py | 6 +-
whisper/convert.py | 176 ++++++++++++++++++++-------------
whisper/mlx_whisper/version.py | 2 +-
whisper/test.py | 6 +-
5 files changed, 116 insertions(+), 75 deletions(-)
diff --git a/README.md b/README.md
index 2ca11d4b..00e57803 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,7 @@ Some more useful examples are listed below.
- Joint text and image embeddings with [CLIP](clip).
- Text generation from image and text inputs with [LLaVA](llava).
+- Image segmentation with [Segment Anything (SAM)](segment_anything).
### Other Models
diff --git a/llms/mlx_lm/gguf.py b/llms/mlx_lm/gguf.py
index 1f858d70..5d524580 100644
--- a/llms/mlx_lm/gguf.py
+++ b/llms/mlx_lm/gguf.py
@@ -59,7 +59,7 @@ class HfVocab:
for token_id in range(self.vocab_size_base):
if token_id in self.added_tokens_ids:
continue
- token_text = reverse_vocab[token_id].encode("utf-8")
+ token_text = reverse_vocab[token_id]
yield token_text, self.get_token_score(token_id), self.get_token_type(
token_id, token_text, self.special_ids
)
@@ -67,7 +67,7 @@ class HfVocab:
def get_token_type(
self, token_id: int, token_text: bytes, special_ids: Set[int]
) -> TokenType:
- if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
+ if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text.encode("utf-8")):
return TokenType.BYTE
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
@@ -84,7 +84,7 @@ class HfVocab:
else:
toktype = TokenType.USER_DEFINED
score = -1000.0
- yield text.encode("utf-8"), score, toktype
+ yield text, score, toktype
def has_newline_token(self):
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
diff --git a/whisper/convert.py b/whisper/convert.py
index 37825d6c..85ce5fba 100644
--- a/whisper/convert.py
+++ b/whisper/convert.py
@@ -105,6 +105,88 @@ def available_models() -> List[str]:
return list(_MODELS.keys())
+def hf_to_pt(weights, config):
+ config = {
+ "n_mels": config["num_mel_bins"],
+ "n_audio_ctx": config["max_source_positions"],
+ "n_audio_state": config["d_model"],
+ "n_audio_head": config["encoder_attention_heads"],
+ "n_audio_layer": config["encoder_layers"],
+ "n_vocab": config["vocab_size"],
+ "n_text_ctx": config["max_target_positions"],
+ "n_text_state": config["d_model"],
+ "n_text_head": config["decoder_attention_heads"],
+ "n_text_layer": config["decoder_layers"],
+ }
+
+ def remap(k):
+ k = k.replace("model.", "")
+ k = k.replace(".layers", ".blocks")
+ k = k.replace(".self_attn", ".attn")
+ k = k.replace(".attn_layer_norm", ".attn_ln")
+ k = k.replace(".encoder_attn.", ".cross_attn.")
+ k = k.replace(".encoder_attn_layer_norm", ".cross_attn_ln")
+ k = k.replace(".final_layer_norm", ".mlp_ln")
+ k = k.replace(".q_proj", ".query")
+ k = k.replace(".k_proj", ".key")
+ k = k.replace(".v_proj", ".value")
+ k = k.replace(".out_proj", ".out")
+ k = k.replace(".fc1", ".mlp1")
+ k = k.replace(".fc2", ".mlp2")
+ k = k.replace("embed_positions.weight", "positional_embedding")
+ k = k.replace("decoder.embed_tokens", "decoder.token_embedding")
+ k = k.replace("encoder.layer_norm", "encoder.ln_post")
+ k = k.replace("decoder.layer_norm", "decoder.ln")
+ return k
+
+ # token embeddings are shared with output projection
+ weights.pop("proj_out.weight", None)
+ weights = {remap(k): v for k, v in weights.items()}
+ return weights, config
+
+
+def load_torch_weights_and_config(
+ name_or_path: str,
+ download_root: str = None,
+):
+ if download_root is None:
+ download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
+
+ # todo: accept alignment_heads of local Pytorch checkpoint
+ alignment_heads = None
+ if name_or_path in _MODELS:
+ alignment_heads = _ALIGNMENT_HEADS[name_or_path]
+ name_or_path = _download(_MODELS[name_or_path], download_root)
+ elif not Path(name_or_path).exists():
+ # Try downloading from HF
+ from huggingface_hub import snapshot_download
+
+ name_or_path = snapshot_download(
+ repo_id=name_or_path,
+ allow_patterns=["*.json", "pytorch_model.bin", "*.txt"],
+ )
+ else:
+ raise RuntimeError(
+ f"Model {name_or_path} is not found in {available_models()},"
+ "on Hugging Face or as a local path."
+ )
+
+ if name_or_path.endswith(".pt"):
+ checkpoint = torch.load(name_or_path, map_location="cpu")
+ weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
+ else:
+ name_or_path = Path(name_or_path)
+ weights = torch.load(
+ name_or_path / "pytorch_model.bin",
+ map_location="cpu",
+ )
+ with open(name_or_path / "config.json", "r") as fp:
+ config = json.load(fp)
+ weights, config = hf_to_pt(weights, config)
+
+ return weights, config, alignment_heads
+
+
def load_torch_model(
name_or_path: str,
download_root: str = None,
@@ -115,7 +197,8 @@ def load_torch_model(
Parameters
----------
name_or_path : str
- one of the official model names listed by `whisper.available_models()` or a local Pytorch checkpoint which is in the original OpenAI format
+ one of the official model names listed by `whisper.available_models()` or
+ a local Pytorch checkpoint which is in the original OpenAI format
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
@@ -128,22 +211,12 @@ def load_torch_model(
if download_root is None:
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
- # todo: accept alignment_heads of local Pytorch checkpoint
- alignment_heads = None
- if name_or_path in _MODELS:
- alignment_heads = _ALIGNMENT_HEADS[name_or_path]
- name_or_path = _download(_MODELS[name_or_path], download_root)
- elif not Path(name_or_path).is_file():
- raise RuntimeError(
- f"Model {name_or_path} is neither found in {available_models()} nor as a local path"
- )
-
- with open(name_or_path, "rb") as fp:
- checkpoint = torch.load(fp)
-
- dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
+ weights, config, alignment_heads = load_torch_weights_and_config(
+ name_or_path, download_root
+ )
+ dims = torch_whisper.ModelDimensions(**config)
model = torch_whisper.Whisper(dims)
- model.load_state_dict(checkpoint["model_state_dict"])
+ model.load_state_dict(weights)
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
@@ -151,59 +224,26 @@ def load_torch_model(
return model
-def convert(model, rules=None):
- params = {}
- if rules is not None and type(model) in rules:
- out = rules[type(model)](model, rules)
- return out
- if isinstance(model, torch.Tensor):
- return mx.array(model.detach().numpy())
- if isinstance(model, torch.nn.ModuleList):
- return [convert(n, rules) for n in model.children()]
- if isinstance(model, torch.nn.Conv1d):
- return {
- "weight": convert(model.weight).transpose(0, 2, 1),
- "bias": convert(model.bias),
- }
- for k, n in model.named_children():
- if k in rules:
- params.update(rules[k](n, rules))
- else:
- params[k] = convert(n, rules)
- for k, p in model.named_parameters(recurse=False):
- params[k] = convert(p)
- return params
+def convert(name_or_path: str, dtype: mx.Dtype = mx.float16):
+ def remap(key, value):
+ key = key.replace("mlp.0", "mlp1")
+ key = key.replace("mlp.2", "mlp2")
+ if "conv" in key and value.ndim == 3:
+ value = value.swapaxes(1, 2)
+ return key, mx.array(value.detach()).astype(dtype)
+ weights, config, alignment_heads = load_torch_weights_and_config(name_or_path)
+ weights.pop("encoder.positional_embedding", None)
+ weights = dict(remap(k, v) for k, v in weights.items())
-def torch_to_mlx(
- torch_model: torch_whisper.Whisper,
- dtype: mx.Dtype = mx.float16,
-) -> Whisper:
- def convert_rblock(model, rules):
- children = dict(model.named_children())
- mlp = list(children.pop("mlp").children())
- params = {
- "mlp1": convert(mlp[0], rules),
- "mlp2": convert(mlp[-1], rules),
- }
- for k, n in children.items():
- params[k] = convert(n, rules)
- return params
+ model_dims = ModelDimensions(**config)
+ model = Whisper(model_dims, dtype)
+ model.load_weights(list(weights.items()), strict=False)
- rules = {
- torch_whisper.ResidualAttentionBlock: convert_rblock,
- }
+ if alignment_heads is not None:
+ model.set_alignment_heads(alignment_heads)
- params = convert(torch_model, rules)
-
- mlx_model = Whisper(torch_model.dims, dtype)
- params = tree_map(lambda p: p.astype(dtype), params)
- mlx_model.update(params)
-
- if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
- mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())
-
- return mlx_model
+ return model
def upload_to_hub(path: str, name: str, torch_name_or_path: str):
@@ -292,13 +332,13 @@ if __name__ == "__main__":
action="store_true",
)
parser.add_argument(
- "--q_group_size",
+ "--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
- "--q_bits",
+ "--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,
@@ -318,7 +358,7 @@ if __name__ == "__main__":
dtype = getattr(mx, args.dtype)
print("[INFO] Loading")
- model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype)
+ model = convert(args.torch_name_or_path, dtype)
config = asdict(model.dims)
weights = dict(tree_flatten(model.parameters()))
diff --git a/whisper/mlx_whisper/version.py b/whisper/mlx_whisper/version.py
index 87ee07a7..ae3cfb71 100644
--- a/whisper/mlx_whisper/version.py
+++ b/whisper/mlx_whisper/version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
-__version__ = "0.1.0"
+__version__ = "0.2.0"
diff --git a/whisper/test.py b/whisper/test.py
index ce559251..f0acb3cd 100644
--- a/whisper/test.py
+++ b/whisper/test.py
@@ -13,7 +13,7 @@ import mlx_whisper.decoding as decoding
import mlx_whisper.load_models as load_models
import numpy as np
import torch
-from convert import load_torch_model, quantize, torch_to_mlx
+from convert import convert, load_torch_model, quantize
from mlx.utils import tree_flatten
MODEL_NAME = "tiny"
@@ -41,12 +41,12 @@ def _save_model(save_dir, weights, config):
def load_torch_and_mlx():
torch_model = load_torch_model(MODEL_NAME)
- fp32_model = torch_to_mlx(torch_model, dtype=mx.float32)
+ fp32_model = convert(MODEL_NAME, dtype=mx.float32)
config = asdict(fp32_model.dims)
weights = dict(tree_flatten(fp32_model.parameters()))
_save_model(MLX_FP32_MODEL_PATH, weights, config)
- fp16_model = torch_to_mlx(torch_model, dtype=mx.float16)
+ fp16_model = convert(MODEL_NAME, dtype=mx.float16)
config = asdict(fp16_model.dims)
weights = dict(tree_flatten(fp16_model.parameters()))
_save_model(MLX_FP16_MODEL_PATH, weights, config)
From 95840f32e28ca2af4ec49096fc4ddace51e0fb30 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 14 Aug 2024 10:22:04 -0700
Subject: [PATCH 007/188] Fix whisper conversion for safetensors models (#935)
* fix whisper conversion for safetensors-only models; raise an error in mlx_lm convert when the save path already exists
* fix tests
---
llms/mlx_lm/utils.py | 13 ++++++++++---
llms/tests/test_utils.py | 1 +
whisper/convert.py | 34 +++++++++++++++++++++++-----------
3 files changed, 34 insertions(+), 14 deletions(-)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index cffa2a89..0e7f7a39 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -660,6 +660,16 @@ def convert(
revision: Optional[str] = None,
dequantize: bool = False,
):
+ # Check the save path is empty
+ if isinstance(mlx_path, str):
+ mlx_path = Path(mlx_path)
+
+ if mlx_path.exists():
+ raise ValueError(
+ f"Cannot save to the path {mlx_path} as it already exists."
+ " Please delete the file/directory or specify a new path to save to."
+ )
+
print("[INFO] Loading")
model_path = get_model_path(hf_path, revision=revision)
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
@@ -681,9 +691,6 @@ def convert(
model = dequantize_model(model)
weights = dict(tree_flatten(model.parameters()))
- if isinstance(mlx_path, str):
- mlx_path = Path(mlx_path)
-
del model
save_weights(mlx_path, weights, donate_weights=True)
diff --git a/llms/tests/test_utils.py b/llms/tests/test_utils.py
index 576c2820..18cfa8c7 100644
--- a/llms/tests/test_utils.py
+++ b/llms/tests/test_utils.py
@@ -82,6 +82,7 @@ class TestUtils(unittest.TestCase):
self.assertTrue(isinstance(model.layers[-1].mlp.up_proj, nn.QuantizedLinear))
# Check model weights have right type
+ mlx_path = os.path.join(self.test_dir, "mlx_model_bf16")
utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, dtype="bfloat16")
model, _ = utils.load(mlx_path)
diff --git a/whisper/convert.py b/whisper/convert.py
index 85ce5fba..da7195e0 100644
--- a/whisper/convert.py
+++ b/whisper/convert.py
@@ -163,7 +163,12 @@ def load_torch_weights_and_config(
name_or_path = snapshot_download(
repo_id=name_or_path,
- allow_patterns=["*.json", "pytorch_model.bin", "*.txt"],
+ allow_patterns=[
+ "*.json",
+ "pytorch_model.bin",
+ "model.safetensors",
+ "*.txt",
+ ],
)
else:
raise RuntimeError(
@@ -176,10 +181,11 @@ def load_torch_weights_and_config(
weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
else:
name_or_path = Path(name_or_path)
- weights = torch.load(
- name_or_path / "pytorch_model.bin",
- map_location="cpu",
- )
+ pt_path = name_or_path / "pytorch_model.bin"
+ if pt_path.is_file():
+ weights = torch.load(pt_path, map_location="cpu")
+ else:
+ weights = mx.load(str(name_or_path / "model.safetensors"))
with open(name_or_path / "config.json", "r") as fp:
config = json.load(fp)
weights, config = hf_to_pt(weights, config)
@@ -230,7 +236,9 @@ def convert(name_or_path: str, dtype: mx.Dtype = mx.float16):
key = key.replace("mlp.2", "mlp2")
if "conv" in key and value.ndim == 3:
value = value.swapaxes(1, 2)
- return key, mx.array(value.detach()).astype(dtype)
+ if isinstance(value, torch.Tensor):
+ value = mx.array(value.detach())
+ return key, value.astype(dtype)
weights, config, alignment_heads = load_torch_weights_and_config(name_or_path)
weights.pop("encoder.positional_embedding", None)
@@ -262,12 +270,16 @@ This model was converted to MLX format from [`{torch_name_or_path}`]().
## Use with mlx
```bash
-git clone https://github.com/ml-explore/mlx-examples.git
-cd mlx-examples/whisper/
-pip install -r requirements.txt
+pip install mlx-whisper
+```
->> import whisper
->> whisper.transcribe("FILE_NAME")
+```python
+import mlx_whisper
+
+result = mlx_whisper.transcribe(
+ "FILE_NAME",
+ path_or_hf_repo={repo_id},
+)
```
"""
card = ModelCard(text)
From 9b83004631bf0d2c225d4bc31c0ed613b1b15b8f Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 15 Aug 2024 11:29:09 -0700
Subject: [PATCH 008/188] Faster sampling with `mx.compile` (#937)
* faster sampling with compile
* fix test
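The essential pattern, as a sketch mirroring the new helpers in `sample_utils.py`: the sampler is wrapped with `mx.compile` and the global PRNG state is declared as both an input and an output, so the compiled graph reads and advances the random state correctly across calls:
```
from functools import partial

import mlx.core as mx

@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def categorical_sampling(logits, temp):
    # Scale by inverse temperature and draw a token id.
    return mx.random.categorical(logits * (1 / temp))

token = categorical_sampling(mx.zeros((1, 8)), 1.0)
```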
---
llms/mlx_lm/sample_utils.py | 14 +++++++++--
llms/mlx_lm/utils.py | 4 ++--
llms/tests/test_sample_utils.py | 42 ++++++++++++++-------------------
3 files changed, 32 insertions(+), 28 deletions(-)
diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index f22ce2d7..2e9c172e 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -1,6 +1,11 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from functools import partial
+
import mlx.core as mx
+@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
"""
Apply top-p (nucleus) sampling to logits.
@@ -13,7 +18,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
token selected based on the top-p criterion.
"""
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
- probs = mx.softmax(logits / temperature, axis=-1)
+ probs = mx.softmax(logits * (1 / temperature), axis=-1)
# sort probs in ascending order
sorted_indices = mx.argsort(probs, axis=-1)
@@ -25,10 +30,15 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
top_probs = mx.where(
cumulative_probs > 1 - top_p,
sorted_probs,
- mx.zeros_like(sorted_probs),
+ 0,
)
sorted_token = mx.random.categorical(mx.log(top_probs))
token = sorted_indices.squeeze(0)[sorted_token]
return token
+
+
+@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
+def categorical_sampling(logits, temp):
+ return mx.random.categorical(logits * (1 / temp))
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 0e7f7a39..a34cc6ad 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
# Local imports
from .models.base import KVCache
-from .sample_utils import top_p_sampling
+from .sample_utils import categorical_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import apply_lora_layers
from .tuner.utils import dequantize as dequantize_model
@@ -169,7 +169,7 @@ def generate_step(
if top_p > 0 and top_p < 1.0:
token = top_p_sampling(logits, top_p, temp)
else:
- token = mx.random.categorical(logits * (1 / temp))
+ token = categorical_sampling(logits, temp)
return token, logprobs
diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py
index 0bccdd07..ec0e2cb7 100644
--- a/llms/tests/test_sample_utils.py
+++ b/llms/tests/test_sample_utils.py
@@ -1,38 +1,32 @@
import unittest
-from unittest.mock import patch
import mlx.core as mx
from mlx_lm.sample_utils import top_p_sampling
class TestSamplingUtils(unittest.TestCase):
- @patch("mlx.core.random.categorical")
- def test_top_p_sampling(self, mock_categorical):
- logits = mx.array([[1.0, 2.0, 3.0, 4.0]])
- top_p = 0.3
+ def test_top_p_sampling(self):
+ probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
+ logits = mx.log(probs)
temperature = 1.0
- expected_token = mx.array([3])
- mock_categorical.return_value = expected_token
- token = top_p_sampling(logits, top_p, temperature)
- expected_top_probs = mx.array([[0.0, 0.0, 0.0, 0.643914]])
- self.assertTrue(mx.allclose(token, expected_token))
- args, _ = mock_categorical.call_args
- self.assertTrue(args[0].shape == expected_top_probs.shape)
- self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs)))
+ token = top_p_sampling(logits, 0.3, temperature).item()
+ self.assertEqual(token, 0)
- logits = mx.array([[1.0, 2.0, 3.0, 4.0]])
- top_p = 0.9
- temperature = 1.0
- expected_token = mx.array([3])
- mock_categorical.return_value = expected_token
+ token = top_p_sampling(logits, 0.95, temperature).item()
+ self.assertTrue(token in (0, 3))
- token = top_p_sampling(logits, top_p, temperature)
- expected_top_probs = mx.array([[0.0, 0.0871443, 0.236883, 0.643914]])
- self.assertTrue(mx.allclose(token, expected_token))
- args, _ = mock_categorical.call_args
- self.assertTrue(args[0].shape == expected_top_probs.shape)
- self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs)))
+ probs = mx.array([0.0, 0.5, 0.4, 0.1])[None]
+ logits = mx.log(probs)
+
+ token = top_p_sampling(logits, 0.4, temperature).item()
+ self.assertEqual(token, 1)
+
+ token = top_p_sampling(logits, 0.6, temperature).item()
+ self.assertTrue(token in (1, 2))
+
+ token = top_p_sampling(logits, 0.95, temperature).item()
+ self.assertTrue(token in (1, 2, 3))
if __name__ == "__main__":
From c50971e8608e05c42d8078b7009dc5b9c33d25c0 Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Thu, 15 Aug 2024 18:45:02 -0400
Subject: [PATCH 009/188] Min P implementation (#926)
* Min P implementation
* Change default to 0 (no min_p)
* nits
* nits
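
For illustration only (not part of the diff): min-p keeps a token only if its probability is at least `min_p` times the top token's probability, so the sharper the distribution, the more aggressively it prunes. A quick check with made-up probabilities:

```python
import mlx.core as mx

from mlx_lm.sample_utils import min_p_sampling

probs = mx.array([0.6, 0.25, 0.1, 0.05])[None]
logits = mx.log(probs)

# The threshold is 0.2 * 0.6 = 0.12, so only tokens 0 and 1 can be drawn.
token = min_p_sampling(logits, 0.2).item()
print(token in (0, 1))  # True
```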
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/sample_utils.py | 58 +++++++++++++++++++++++++++++++++++++
llms/mlx_lm/utils.py | 10 ++++++-
2 files changed, 67 insertions(+), 1 deletion(-)
diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index 2e9c172e..20b008fa 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -5,6 +5,64 @@ from functools import partial
import mlx.core as mx
+@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
+def min_p_sampling(
+ logits: mx.array,
+ min_p: float,
+ min_tokens_to_keep: int = 1,
+ temperature=1.0,
+) -> mx.array:
+ """
+ Apply min-p sampling to the logits.
+
+ Min-p keeps all tokens that are above a minimum probability, scaled by the
+ probability of the most likely token. As a result, the filter is more
+ aggressive given a very high-probability token.
+
+ Args:
+ logits: The logits from the model's output.
+ min_p (float): Minimum token probability. Typical values are in the
+ 0.01-0.2 range, which is about as selective as setting `top_p` in the
+ 0.99-0.8 range.
+ min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
+ be filtered. Default: ``1``.
+
+ """
+ if not (0 <= min_p <= 1.0):
+ raise ValueError(
+ f"`min_p` has to be a float in the [0, 1] interval, but is {min_p}"
+ )
+ if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
+ raise ValueError(
+ f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
+ )
+ # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
+
+ # Softmax probabilities
+ probs = mx.softmax(logits * (1 / temperature), axis=-1)
+
+ # Indices sorted in decreasing order
+ sorted_indices = mx.argsort(-logits).squeeze(0)
+ sorted_probs = probs[..., sorted_indices]
+
+ # Top probability
+ top_probs = probs[..., sorted_indices[0]]
+
+ # Calculate the min_p threshold
+ scaled_min_p = min_p * top_probs
+
+ # Mask tokens that have a probability less than the scaled min_p
+ tokens_to_remove = sorted_probs < scaled_min_p
+ tokens_to_remove[..., :min_tokens_to_keep] = False
+
+ # Create pool of tokens with probability less than scaled min_p
+ selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)
+
+ # Return sampled token
+ sorted_token = mx.random.categorical(mx.log(selected_probs))
+ return sorted_indices[sorted_token]
+
+
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
"""
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index a34cc6ad..e7a9dba8 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
# Local imports
from .models.base import KVCache
-from .sample_utils import categorical_sampling, top_p_sampling
+from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import apply_lora_layers
from .tuner.utils import dequantize as dequantize_model
@@ -133,6 +133,8 @@ def generate_step(
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
top_p: float = 1.0,
+ min_p: float = 0.0,
+ min_tokens_to_keep: int = 1,
logit_bias: Optional[Dict[int, float]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
@@ -149,6 +151,10 @@ def generate_step(
consider for repetition penalty. Default: ``20``.
top_p (float, optional): Nucleus sampling; higher values mean the model
considers more less-likely words.
+ min_p (float, optional): The minimum value (scaled by the top token's
+ probability) that a token probability must have to be considered.
+ min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
+ be filtered by min_p sampling.
logit_bias (dictionary, optional): Additive logit bias.
Yields:
@@ -168,6 +174,8 @@ def generate_step(
else:
if top_p > 0 and top_p < 1.0:
token = top_p_sampling(logits, top_p, temp)
+ elif min_p != 0.0:
+ token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
else:
token = categorical_sampling(logits, temp)
From 4e017008169d31fe943b8ca6023f058a724e555f Mon Sep 17 00:00:00 2001
From: Zai Thottakath
Date: Fri, 16 Aug 2024 09:38:36 -0500
Subject: [PATCH 010/188] Allow the entire model to be targeted for LoRA and DoRA
fine tuning: LoRA and DoRA embeddings with small DoRALinear bug fix (#914)
* feature: LoRA adapter for Embeddings
* feature: wire LoRAEmbedding into the tuner; allow the embedding and non-model.layers Linear layers to be targeted for fine tuning
* feature: DoRA adapter for Embeddings
* feature: wire in DoRAEmbedding
* bugfix: ensure self.m is recalculated when the linear layer is changed in DoRALinear.from_linear
* refactor: prefer from_base over from_linear or from_embedding. prefer fuse over to_linear or to_embedding
* cleanup: remove unused imports in test_dora.py
* refactor: remove unnecessary non_layer_modules
* cleanup: remove incorrect comments for lora embedding dropout; remove unnecessary parens in dora embedding dropout
* nits
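
For illustration only (not part of the diff), the new embedding adapters follow the same `from_base`/`fuse` flow as the existing linear adapters; the shapes below are arbitrary:

```python
import mlx.core as mx
import mlx.nn as nn

from mlx_lm.tuner.lora import LoRAEmbedding

embedding = nn.Embedding(256, 512)
lora_emb = LoRAEmbedding.from_base(embedding, r=8, dropout=0.0, scale=10.0)

tokens = mx.array([1, 2, 3])
out = lora_emb(tokens)   # base lookup plus the scaled low-rank update
fused = lora_emb.fuse()  # bake the update back into a plain nn.Embedding
print(out.shape, fused.weight.shape)  # (3, 512) (256, 512)
```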
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/fuse.py | 10 +-
llms/mlx_lm/tuner/dora.py | 110 +++++++++++++++++-
llms/mlx_lm/tuner/lora.py | 102 +++++++++++++++-
llms/mlx_lm/tuner/utils.py | 15 ++-
llms/tests/{test_lora.py => test_finetune.py} | 90 +++++++++++++-
5 files changed, 306 insertions(+), 21 deletions(-)
rename llms/tests/{test_lora.py => test_finetune.py} (65%)
diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py
index fa06eb54..16457036 100644
--- a/llms/mlx_lm/fuse.py
+++ b/llms/mlx_lm/fuse.py
@@ -6,8 +6,8 @@ from pathlib import Path
from mlx.utils import tree_flatten, tree_unflatten
from .gguf import convert_to_gguf
-from .tuner.dora import DoRALinear
-from .tuner.lora import LoRALinear, LoRASwitchLinear
+from .tuner.dora import DoRAEmbedding, DoRALinear
+from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
from .tuner.utils import apply_lora_layers, dequantize
from .utils import (
fetch_from_hub,
@@ -80,9 +80,11 @@ def main() -> None:
model = apply_lora_layers(model, args.adapter_path)
fused_linears = [
- (n, m.to_linear())
+ (n, m.fuse())
for n, m in model.named_modules()
- if isinstance(m, (LoRASwitchLinear, LoRALinear, DoRALinear))
+ if isinstance(
+ m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding)
+ )
]
model.update_modules(tree_unflatten(fused_linears))
diff --git a/llms/mlx_lm/tuner/dora.py b/llms/mlx_lm/tuner/dora.py
index de10556b..bd2dfb01 100644
--- a/llms/mlx_lm/tuner/dora.py
+++ b/llms/mlx_lm/tuner/dora.py
@@ -8,7 +8,7 @@ import mlx.nn as nn
class DoRALinear(nn.Module):
@staticmethod
- def from_linear(
+ def from_base(
linear: nn.Linear,
r: int = 8,
dropout: float = 0.0,
@@ -25,10 +25,10 @@ class DoRALinear(nn.Module):
dropout=dropout,
scale=scale,
)
- dora_lin.linear = linear
+ dora_lin.set_linear(linear)
return dora_lin
- def to_linear(self, de_quantize: bool = False):
+ def fuse(self, de_quantize: bool = False):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
@@ -61,7 +61,7 @@ class DoRALinear(nn.Module):
super().__init__()
# Regular linear layer weights
- self.linear = nn.Linear(input_dims, output_dims, bias=bias)
+ self.set_linear(nn.Linear(input_dims, output_dims, bias=bias))
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
@@ -75,6 +75,9 @@ class DoRALinear(nn.Module):
shape=(input_dims, r),
)
self.lora_b = mx.zeros(shape=(r, output_dims))
+
+ def set_linear(self, linear: nn.Linear):
+ self.linear = linear
self.m = mx.linalg.norm(self.linear.weight, axis=1)
def __call__(self, x):
@@ -93,3 +96,102 @@ class DoRALinear(nn.Module):
if "bias" in self.linear:
out = out + self.linear.bias
return out
+
+
+class DoRAEmbedding(nn.Module):
+ def from_base(
+ embedding: nn.Embedding,
+ r: int = 8,
+ dropout: float = 0.0,
+ scale: float = 20.0,
+ ):
+ num_embeddings, dims = embedding.weight.shape
+
+ # TODO support quantized weights in DoRAEmbedding
+ if isinstance(embedding, nn.QuantizedEmbedding):
+ raise ValueError("DoRAEmbedding does not yet support quantization.")
+ dora_embedding = DoRAEmbedding(
+ num_embeddings=num_embeddings,
+ dims=dims,
+ r=r,
+ dropout=dropout,
+ scale=scale,
+ )
+ dora_embedding.set_embedding(embedding)
+ return dora_embedding
+
+ def fuse(self, de_quantize: bool = False):
+ embedding = self.embedding
+ weight = embedding.weight
+
+ # Use the same dtype as the embedding weight
+ dtype = weight.dtype
+
+ num_embeddings, dims = weight.shape
+ fused_embedding = nn.Embedding(num_embeddings, dims)
+
+ lora_a = (self.scale * self.lora_a).astype(dtype)
+ lora_b = self.lora_b.astype(dtype)
+ weight = weight + lora_a @ lora_b
+ norm_scale = self.m / mx.linalg.norm(weight, axis=1)
+ fused_embedding.weight = norm_scale[:, None] * weight
+
+ return fused_embedding
+
+ def __init__(
+ self,
+ num_embeddings: int,
+ dims: int,
+ r: int = 8,
+ dropout: float = 0.0,
+ scale: float = 20.0,
+ ):
+ super().__init__()
+
+ # Regular embedding layer weights
+ self.set_embedding(nn.Embedding(num_embeddings, dims))
+ self.dropout = nn.Dropout(p=dropout)
+
+ # Scale for low-rank update
+ self.scale = scale
+
+ # Low rank lora weights
+ scale = 1 / math.sqrt(num_embeddings)
+ self.lora_a = mx.random.uniform(
+ low=-scale,
+ high=scale,
+ shape=(num_embeddings, r),
+ )
+ self.lora_b = mx.zeros(shape=(r, dims))
+
+ def set_embedding(self, embedding: nn.Module):
+ self.embedding = embedding
+ self.m = mx.linalg.norm(embedding.weight, axis=1)
+
+ def __call__(self, x):
+ y = self.embedding(x)
+ z = self.scale * self.lora_a[x] @ self.lora_b
+ out = y + self.dropout(z).astype(y.dtype)
+
+ # Compute the norm of the adapted weights for the individual embeddings
+ adapted = y + z
+ denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=-1))
+
+ # Remove the norm and scale by the learned magnitude
+ out = (self.m[x] / denom)[..., None] * out
+
+ return out
+
+ def as_linear(self, x):
+ y = self.embedding.as_linear(x)
+ z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T
+ out = y + (self.scale * z).astype(x.dtype)
+
+ # Compute the norm of the adapted weights
+ adapted = self.embedding.weight + (self.scale * self.lora_a) @ self.lora_b
+ denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
+
+ # Remove the norm and scale by the learned magnitude
+ out = (self.m / denom) * out
+
+ return out
diff --git a/llms/mlx_lm/tuner/lora.py b/llms/mlx_lm/tuner/lora.py
index 19babb0e..c788cb73 100644
--- a/llms/mlx_lm/tuner/lora.py
+++ b/llms/mlx_lm/tuner/lora.py
@@ -10,7 +10,7 @@ from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
class LoRALinear(nn.Module):
@staticmethod
- def from_linear(
+ def from_base(
linear: nn.Linear,
r: int = 8,
dropout: float = 0.0,
@@ -31,7 +31,7 @@ class LoRALinear(nn.Module):
lora_lin.linear = linear
return lora_lin
- def to_linear(self, de_quantize: bool = False):
+ def fuse(self, de_quantize: bool = False):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
@@ -41,7 +41,7 @@ class LoRALinear(nn.Module):
dtype = weight.dtype
if is_quantized:
- dtype = mx.float16
+ dtype = linear.scales.dtype
weight = mx.dequantize(
weight,
linear.scales,
@@ -103,7 +103,7 @@ class LoRALinear(nn.Module):
class LoRASwitchLinear(nn.Module):
@staticmethod
- def from_linear(
+ def from_base(
linear: nn.Module,
r: int = 8,
dropout: float = 0.0,
@@ -120,7 +120,7 @@ class LoRASwitchLinear(nn.Module):
lora_lin.linear = linear
return lora_lin
- def to_linear(self, de_quantize: bool = False):
+ def fuse(self, de_quantize: bool = False):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
@@ -191,3 +191,95 @@ class LoRASwitchLinear(nn.Module):
z = z[..., None, :] @ self.lora_b[indices].swapaxes(-2, -1)
return y + (self.scale * z).astype(x.dtype)
+
+
+class LoRAEmbedding(nn.Module):
+ @staticmethod
+ def from_base(
+ embedding: nn.Embedding,
+ r: int = 8,
+ dropout: float = 0.0,
+ scale: float = 20.0,
+ ):
+ num_embeddings, dims = embedding.weight.shape
+ if isinstance(embedding, nn.QuantizedEmbedding):
+ dims *= 32 // embedding.bits
+ lora_embedding = LoRAEmbedding(
+ num_embeddings=num_embeddings,
+ dims=dims,
+ r=r,
+ dropout=dropout,
+ scale=scale,
+ )
+ lora_embedding.embedding = embedding
+ return lora_embedding
+
+ def fuse(self, de_quantize: bool = False):
+ embedding = self.embedding
+ weight = embedding.weight
+ is_quantized = isinstance(embedding, nn.QuantizedEmbedding)
+
+ # Use the same dtype as the embedding weight if not quantized
+ dtype = weight.dtype
+
+ if is_quantized:
+ dtype = embedding.scales.dtype
+ weight = mx.dequantize(
+ weight,
+ embedding.scales,
+ embedding.biases,
+ embedding.group_size,
+ embedding.bits,
+ )
+ num_embeddings, dims = weight.shape
+ fused_embedding = nn.Embedding(num_embeddings, dims)
+
+ lora_a = (self.scale * self.lora_a).astype(dtype)
+ lora_b = self.lora_b.astype(dtype)
+ fused_embedding.weight = weight + lora_a @ lora_b
+
+ if is_quantized and not de_quantize:
+ fused_embedding = nn.QuantizedEmbedding.from_embedding(
+ fused_embedding,
+ embedding.group_size,
+ embedding.bits,
+ )
+
+ return fused_embedding
+
+ def __init__(
+ self,
+ num_embeddings: int,
+ dims: int,
+ r: int = 8,
+ dropout: float = 0.0,
+ scale: float = 20.0,
+ ):
+ super().__init__()
+
+ # Regular embedding layer
+ self.embedding = nn.Embedding(num_embeddings, dims)
+ self.dropout = nn.Dropout(p=dropout)
+
+ # Scale for low-rank update
+ self.scale = scale
+
+ # Low rank lora weights
+ scale = 1 / math.sqrt(num_embeddings)
+ self.lora_a = mx.random.uniform(
+ low=-scale,
+ high=scale,
+ shape=(num_embeddings, r),
+ )
+ self.lora_b = mx.zeros(shape=(r, dims))
+
+ def __call__(self, x):
+ y = self.embedding(x)
+ z = self.dropout(self.lora_a[x] @ self.lora_b)
+ out = y + (self.scale * z).astype(y.dtype)
+ return out
+
+ def as_linear(self, x):
+ y = self.embedding.as_linear(x)
+ z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T
+ return y + (self.scale * z).astype(x.dtype)
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 2c97228d..b7f1f9de 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -10,8 +10,8 @@ import mlx.optimizers as opt
from mlx.utils import tree_flatten, tree_unflatten
from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
-from .dora import DoRALinear
-from .lora import LoRALinear, LoRASwitchLinear
+from .dora import DoRAEmbedding, DoRALinear
+from .lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
def build_schedule(schedule_config: Dict):
@@ -71,12 +71,14 @@ def linear_to_lora_layers(
if use_dora:
raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
LoRALayer = LoRASwitchLinear
+ elif isinstance(layer, (nn.Embedding, nn.QuantizedEmbedding)):
+ LoRALayer = DoRAEmbedding if use_dora else LoRAEmbedding
else:
raise ValueError(
f"Can't convert layer of type {type(layer).__name__} to LoRA"
)
- return LoRALayer.from_linear(
+ return LoRALayer.from_base(
layer,
r=config["rank"],
scale=config["scale"],
@@ -130,7 +132,12 @@ def linear_to_lora_layers(
for l in model.layers[num_layers - num_lora_layers :]:
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
- l.update_modules(tree_unflatten(lora_layers))
+ if lora_layers:
+ l.update_modules(tree_unflatten(lora_layers))
+
+ lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys]
+ if lora_modules:
+ model.update_modules(tree_unflatten(lora_modules))
def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
diff --git a/llms/tests/test_lora.py b/llms/tests/test_finetune.py
similarity index 65%
rename from llms/tests/test_lora.py
rename to llms/tests/test_finetune.py
index f37ae3c2..289b8cfb 100644
--- a/llms/tests/test_lora.py
+++ b/llms/tests/test_finetune.py
@@ -6,10 +6,13 @@ import unittest
from io import StringIO
from unittest.mock import MagicMock
+import mlx.core as mx
+import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_flatten
from mlx_lm import lora, tuner
-from mlx_lm.tuner.lora import LoRALinear
+from mlx_lm.tuner.dora import DoRAEmbedding
+from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
from mlx_lm.tuner.trainer import evaluate
from mlx_lm.tuner.utils import build_schedule
@@ -33,11 +36,12 @@ class TestLora(unittest.TestCase):
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
+ tie_word_embeddings=False,
)
lora_layers = 4
- def check_config(params):
+ def check_config(params, expected_trainable_parameters=None):
n_keys = 2
if "keys" in params:
n_keys = len(params["keys"])
@@ -47,9 +51,11 @@ class TestLora(unittest.TestCase):
trainable_params = sum(
v.size for _, v in tree_flatten(model.trainable_parameters())
)
- self.assertEqual(
- trainable_params, lora_layers * params["rank"] * 1024 * 2 * n_keys
+
+ expected_trainable_parameters = expected_trainable_parameters or (
+ lora_layers * params["rank"] * args.hidden_size * 2 * n_keys
)
+ self.assertEqual(trainable_params, expected_trainable_parameters)
params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
check_config(params)
@@ -60,6 +66,22 @@ class TestLora(unittest.TestCase):
params["keys"] = ["self_attn.k_proj"]
check_config(params)
+ params["keys"] = ["lm_head"]
+ check_config(
+ params,
+ expected_trainable_parameters=(
+ params["rank"] * (args.hidden_size + args.vocab_size)
+ ),
+ )
+
+ params["keys"] = ["model.embed_tokens"]
+ check_config(
+ params,
+ expected_trainable_parameters=(
+ params["rank"] * (args.hidden_size + args.vocab_size)
+ ),
+ )
+
def test_gpt_neox(self):
from mlx_lm.models import gpt_neox
@@ -82,6 +104,66 @@ class TestLora(unittest.TestCase):
model.freeze()
tuner.utils.linear_to_lora_layers(model, num_lora_layers, params)
+ def test_lora_embedding(self):
+ num_embeddings = 256
+ dims = 512
+ tokens = mx.array([1, 2, 3])
+
+ embedding = nn.QuantizedEmbedding(num_embeddings, dims)
+ dequantized_weight = mx.dequantize(
+ embedding.weight,
+ embedding.scales,
+ embedding.biases,
+ embedding.group_size,
+ embedding.bits,
+ )
+ lora_emb = LoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
+ new_embedding = lora_emb.fuse(de_quantize=True)
+ self.assertTrue(mx.array_equal(dequantized_weight, new_embedding.weight))
+ self.assertTrue(mx.array_equal(embedding(tokens), lora_emb(tokens)))
+
+ # as_linear
+ attn_output = mx.random.uniform(shape=(dims,))
+ embedding_lin_out = lora_emb.as_linear(attn_output)
+ self.assertEqual(embedding_lin_out.shape, (num_embeddings,))
+ self.assertTrue(
+ mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output))
+ )
+
+ # change the value of lora_b and the embeddings will no longer be equal
+ lora_emb.lora_b = mx.random.uniform(shape=lora_emb.lora_b.shape)
+ new_embedding = lora_emb.fuse(de_quantize=True)
+ self.assertFalse(mx.array_equal(dequantized_weight, new_embedding.weight))
+ self.assertFalse(mx.array_equal(embedding(tokens), lora_emb(tokens)))
+
+
+class TestDora(unittest.TestCase):
+ def test_dora_embedding(self):
+ num_embeddings = 256
+ dims = 512
+ tokens = mx.array([1, 2, 3])
+
+ embedding = nn.Embedding(num_embeddings, dims)
+
+ dora_emb = DoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
+ new_embedding = dora_emb.fuse()
+ self.assertTrue(mx.array_equal(embedding.weight, new_embedding.weight))
+ self.assertTrue(mx.array_equal(embedding(tokens), dora_emb(tokens)))
+
+ # as_linear
+ attn_output = mx.random.uniform(shape=(dims,))
+ embedding_lin_out = dora_emb.as_linear(attn_output)
+ self.assertEqual(embedding_lin_out.shape, (num_embeddings,))
+ self.assertTrue(
+ mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output))
+ )
+
+ # change the value of lora_b and the embeddings will no longer be equal
+ dora_emb.lora_b = mx.random.uniform(shape=dora_emb.lora_b.shape)
+ new_embedding = dora_emb.fuse()
+ self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
+ self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))
+
class TestScheduleConfig(unittest.TestCase):
def test_join(self):
From e196fa3208f032af1fa23374eab57169b4171c89 Mon Sep 17 00:00:00 2001
From: madroid
Date: Sat, 17 Aug 2024 01:35:44 +0800
Subject: [PATCH 011/188] Whisper: Support command line (#746)
* Whisper: Add CLI command
* Whisper: Prevent precision loss when converting to words dictionary
* Whisper: disable json ensure_ascii
* Whisper: add cli setup config
* Whisper: pre-commit
* Whisper: Change underscores to hyphens in the command line arguments
* nits
* version + readme
* nit
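
For illustration, a couple of example invocations using options added in this patch (the file names are placeholders; see `mlx_whisper -h` for the full list):

```sh
# Transcribe to SRT with an explicit model and output directory
mlx_whisper audio_file.mp3 --model mlx-community/whisper-tiny -f srt -o ./transcripts

# Translate speech to English with word-level timestamps
mlx_whisper talk.wav --task translate --language ja --word-timestamps True
```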
---------
Co-authored-by: Awni Hannun
---
whisper/README.md | 16 ++
whisper/mlx_whisper/cli.py | 236 ++++++++++++++++++++++++++++
whisper/mlx_whisper/timing.py | 2 +-
whisper/mlx_whisper/version.py | 2 +-
whisper/mlx_whisper/writers.py | 272 +++++++++++++++++++++++++++++++++
whisper/setup.py | 5 +
6 files changed, 531 insertions(+), 2 deletions(-)
create mode 100644 whisper/mlx_whisper/cli.py
create mode 100644 whisper/mlx_whisper/writers.py
diff --git a/whisper/README.md b/whisper/README.md
index 2805e899..ac6e95f6 100644
--- a/whisper/README.md
+++ b/whisper/README.md
@@ -21,6 +21,22 @@ pip install mlx-whisper
### Run
+#### CLI
+
+At its simplest:
+
+```
+mlx_whisper audio_file.mp3
+```
+
+This will make a text file `audio_file.txt` with the results.
+
+Use `-f` to specify the output format and `--model` to specify the model. There
+are many other supported command line options. To see them all, run
+`mlx_whisper -h`.
+
+#### API
+
Transcribe audio with:
```python
diff --git a/whisper/mlx_whisper/cli.py b/whisper/mlx_whisper/cli.py
new file mode 100644
index 00000000..c2813338
--- /dev/null
+++ b/whisper/mlx_whisper/cli.py
@@ -0,0 +1,236 @@
+# Copyright © 2024 Apple Inc.
+
+import argparse
+import os
+import traceback
+import warnings
+
+from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
+from .transcribe import transcribe
+from .writers import get_writer
+
+
+def build_parser():
+ def optional_int(string):
+ return None if string == "None" else int(string)
+
+ def optional_float(string):
+ return None if string == "None" else float(string)
+
+ def str2bool(string):
+ str2val = {"True": True, "False": False}
+ if string in str2val:
+ return str2val[string]
+ else:
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument(
+ "audio", nargs="+", type=str, help="Audio file(s) to transcribe"
+ )
+ parser.add_argument(
+ "--model",
+ default="mlx-community/whisper-tiny",
+ type=str,
+ help="The model directory or hugging face repo",
+ )
+ parser.add_argument(
+ "--output-dir",
+ "-o",
+ type=str,
+ default=".",
+ help="Directory to save the outputs",
+ )
+ parser.add_argument(
+ "--output-format",
+ "-f",
+ type=str,
+ default="txt",
+ choices=["txt", "vtt", "srt", "tsv", "json", "all"],
+ help="Format of the output file",
+ )
+ parser.add_argument(
+ "--verbose",
+ type=str2bool,
+ default=True,
+ help="Whether to print out progress and debug messages",
+ )
+ parser.add_argument(
+ "--task",
+ type=str,
+ default="transcribe",
+ choices=["transcribe", "translate"],
+ help="Perform speech recognition ('transcribe') or speech translation ('translate')",
+ )
+ parser.add_argument(
+ "--language",
+ type=str,
+ default=None,
+ choices=sorted(LANGUAGES.keys())
+ + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
+ help="Language spoken in the audio, specify None to auto-detect",
+ )
+ parser.add_argument(
+ "--temperature", type=float, default=0, help="Temperature for sampling"
+ )
+ parser.add_argument(
+ "--best-of",
+ type=optional_int,
+ default=5,
+ help="Number of candidates when sampling with non-zero temperature",
+ )
+ parser.add_argument(
+ "--patience",
+ type=float,
+ default=None,
+ help="Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search",
+ )
+ parser.add_argument(
+ "--length-penalty",
+ type=float,
+ default=None,
+ help="Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default.",
+ )
+ parser.add_argument(
+ "--suppress-tokens",
+ type=str,
+ default="-1",
+ help="Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations",
+ )
+ parser.add_argument(
+ "--initial-prompt",
+ type=str,
+ default=None,
+ help="Optional text to provide as a prompt for the first window.",
+ )
+ parser.add_argument(
+ "--condition-on-previous-text",
+ type=str2bool,
+ default=True,
+ help="If True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop",
+ )
+ parser.add_argument(
+ "--fp16",
+ type=str2bool,
+ default=True,
+ help="Whether to perform inference in fp16",
+ )
+ parser.add_argument(
+ "--compression-ratio-threshold",
+ type=optional_float,
+ default=2.4,
+ help="if the gzip compression ratio is higher than this value, treat the decoding as failed",
+ )
+ parser.add_argument(
+ "--logprob-threshold",
+ type=optional_float,
+ default=-1.0,
+ help="If the average log probability is lower than this value, treat the decoding as failed",
+ )
+ parser.add_argument(
+ "--no-speech-threshold",
+ type=optional_float,
+ default=0.6,
+ help="If the probability of the token is higher than this value the decoding has failed due to `logprob_threshold`, consider the segment as silence",
+ )
+ parser.add_argument(
+ "--word-timestamps",
+ type=str2bool,
+ default=False,
+ help="Extract word-level timestamps and refine the results based on them",
+ )
+ parser.add_argument(
+ "--prepend-punctuations",
+ type=str,
+ default="\"'“¿([{-",
+ help="If word-timestamps is True, merge these punctuation symbols with the next word",
+ )
+ parser.add_argument(
+ "--append-punctuations",
+ type=str,
+ default="\"'.。,,!!??::”)]}、",
+ help="If word_timestamps is True, merge these punctuation symbols with the previous word",
+ )
+ parser.add_argument(
+ "--highlight-words",
+ type=str2bool,
+ default=False,
+ help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt",
+ )
+ parser.add_argument(
+ "--max-line-width",
+ type=int,
+ default=None,
+ help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line",
+ )
+ parser.add_argument(
+ "--max-line-count",
+ type=int,
+ default=None,
+ help="(requires --word_timestamps True) the maximum number of lines in a segment",
+ )
+ parser.add_argument(
+ "--max-words-per-line",
+ type=int,
+ default=None,
+ help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment",
+ )
+ parser.add_argument(
+ "--hallucination-silence-threshold",
+ type=optional_float,
+ help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
+ )
+ parser.add_argument(
+ "--clip-timestamps",
+ type=str,
+ default="0",
+ help="Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file",
+ )
+ return parser
+
+
+def main():
+ parser = build_parser()
+ args = vars(parser.parse_args())
+ if args["verbose"] is True:
+ print(f"Args: {args}")
+
+ path_or_hf_repo: str = args.pop("model")
+ output_dir: str = args.pop("output_dir")
+ output_format: str = args.pop("output_format")
+ os.makedirs(output_dir, exist_ok=True)
+
+ writer = get_writer(output_format, output_dir)
+ word_options = [
+ "highlight_words",
+ "max_line_count",
+ "max_line_width",
+ "max_words_per_line",
+ ]
+ writer_args = {arg: args.pop(arg) for arg in word_options}
+ if not args["word_timestamps"]:
+ for k, v in writer_args.items():
+ if v:
+ argop = k.replace("_", "-")
+ parser.error(f"--{argop} requires --word-timestamps True")
+ if writer_args["max_line_count"] and not writer_args["max_line_width"]:
+ warnings.warn("--max-line-count has no effect without --max-line-width")
+ if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
+ warnings.warn("--max-words-per-line has no effect with --max-line-width")
+ for audio_path in args.pop("audio"):
+ try:
+ result = transcribe(
+ audio_path,
+ path_or_hf_repo=path_or_hf_repo,
+ **args,
+ )
+ writer(result, audio_path, **writer_args)
+ except Exception as e:
+ traceback.print_exc()
+ print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/whisper/mlx_whisper/timing.py b/whisper/mlx_whisper/timing.py
index 13c36315..04915deb 100644
--- a/whisper/mlx_whisper/timing.py
+++ b/whisper/mlx_whisper/timing.py
@@ -276,7 +276,7 @@ def add_word_timestamps(
word=timing.word,
start=round(time_offset + timing.start, 2),
end=round(time_offset + timing.end, 2),
- probability=timing.probability,
+ probability=float(timing.probability),
)
)
diff --git a/whisper/mlx_whisper/version.py b/whisper/mlx_whisper/version.py
index ae3cfb71..67c7397c 100644
--- a/whisper/mlx_whisper/version.py
+++ b/whisper/mlx_whisper/version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
-__version__ = "0.2.0"
+__version__ = "0.3.0"
diff --git a/whisper/mlx_whisper/writers.py b/whisper/mlx_whisper/writers.py
new file mode 100644
index 00000000..464ead18
--- /dev/null
+++ b/whisper/mlx_whisper/writers.py
@@ -0,0 +1,272 @@
+# Copyright © 2024 Apple Inc.
+
+import json
+import os
+import re
+import sys
+import zlib
+from typing import Callable, List, Optional, TextIO
+
+
+def format_timestamp(
+ seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
+):
+ assert seconds >= 0, "non-negative timestamp expected"
+ milliseconds = round(seconds * 1000.0)
+
+ hours = milliseconds // 3_600_000
+ milliseconds -= hours * 3_600_000
+
+ minutes = milliseconds // 60_000
+ milliseconds -= minutes * 60_000
+
+ seconds = milliseconds // 1_000
+ milliseconds -= seconds * 1_000
+
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+ return (
+ f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+ )
+
+
+def get_start(segments: List[dict]) -> Optional[float]:
+ return next(
+ (w["start"] for s in segments for w in s["words"]),
+ segments[0]["start"] if segments else None,
+ )
+
+
+class ResultWriter:
+ extension: str
+
+ def __init__(self, output_dir: str):
+ self.output_dir = output_dir
+
+ def __call__(
+ self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
+ ):
+ audio_basename = os.path.basename(audio_path)
+ audio_basename = os.path.splitext(audio_basename)[0]
+ output_path = os.path.join(
+ self.output_dir, audio_basename + "." + self.extension
+ )
+
+ with open(output_path, "w", encoding="utf-8") as f:
+ self.write_result(result, file=f, options=options, **kwargs)
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ raise NotImplementedError
+
+
+class WriteTXT(ResultWriter):
+ extension: str = "txt"
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ for segment in result["segments"]:
+ print(segment["text"].strip(), file=file, flush=True)
+
+
+class SubtitlesWriter(ResultWriter):
+ always_include_hours: bool
+ decimal_marker: str
+
+ def iterate_result(
+ self,
+ result: dict,
+ options: Optional[dict] = None,
+ *,
+ max_line_width: Optional[int] = None,
+ max_line_count: Optional[int] = None,
+ highlight_words: bool = False,
+ max_words_per_line: Optional[int] = None,
+ ):
+ options = options or {}
+ max_line_width = max_line_width or options.get("max_line_width")
+ max_line_count = max_line_count or options.get("max_line_count")
+ highlight_words = highlight_words or options.get("highlight_words", False)
+ max_words_per_line = max_words_per_line or options.get("max_words_per_line")
+ preserve_segments = max_line_count is None or max_line_width is None
+ max_line_width = max_line_width or 1000
+ max_words_per_line = max_words_per_line or 1000
+
+ def iterate_subtitles():
+ line_len = 0
+ line_count = 1
+ # the next subtitle to yield (a list of word timings with whitespace)
+ subtitle: List[dict] = []
+ last: float = get_start(result["segments"]) or 0.0
+ for segment in result["segments"]:
+ chunk_index = 0
+ words_count = max_words_per_line
+ while chunk_index < len(segment["words"]):
+ remaining_words = len(segment["words"]) - chunk_index
+ if max_words_per_line > len(segment["words"]) - chunk_index:
+ words_count = remaining_words
+ for i, original_timing in enumerate(
+ segment["words"][chunk_index : chunk_index + words_count]
+ ):
+ timing = original_timing.copy()
+ long_pause = (
+ not preserve_segments and timing["start"] - last > 3.0
+ )
+ has_room = line_len + len(timing["word"]) <= max_line_width
+ seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
+ if (
+ line_len > 0
+ and has_room
+ and not long_pause
+ and not seg_break
+ ):
+ # line continuation
+ line_len += len(timing["word"])
+ else:
+ # new line
+ timing["word"] = timing["word"].strip()
+ if (
+ len(subtitle) > 0
+ and max_line_count is not None
+ and (long_pause or line_count >= max_line_count)
+ or seg_break
+ ):
+ # subtitle break
+ yield subtitle
+ subtitle = []
+ line_count = 1
+ elif line_len > 0:
+ # line break
+ line_count += 1
+ timing["word"] = "\n" + timing["word"]
+ line_len = len(timing["word"].strip())
+ subtitle.append(timing)
+ last = timing["start"]
+ chunk_index += max_words_per_line
+ if len(subtitle) > 0:
+ yield subtitle
+
+ if len(result["segments"]) > 0 and "words" in result["segments"][0]:
+ for subtitle in iterate_subtitles():
+ subtitle_start = self.format_timestamp(subtitle[0]["start"])
+ subtitle_end = self.format_timestamp(subtitle[-1]["end"])
+ subtitle_text = "".join([word["word"] for word in subtitle])
+ if highlight_words:
+ last = subtitle_start
+ all_words = [timing["word"] for timing in subtitle]
+ for i, this_word in enumerate(subtitle):
+ start = self.format_timestamp(this_word["start"])
+ end = self.format_timestamp(this_word["end"])
+ if last != start:
+ yield last, start, subtitle_text
+
+ yield start, end, "".join(
+ [
+ (
+ re.sub(r"^(\s*)(.*)$", r"\1\2", word)
+ if j == i
+ else word
+ )
+ for j, word in enumerate(all_words)
+ ]
+ )
+ last = end
+ else:
+ yield subtitle_start, subtitle_end, subtitle_text
+ else:
+ for segment in result["segments"]:
+ segment_start = self.format_timestamp(segment["start"])
+ segment_end = self.format_timestamp(segment["end"])
+ segment_text = segment["text"].strip().replace("-->", "->")
+ yield segment_start, segment_end, segment_text
+
+ def format_timestamp(self, seconds: float):
+ return format_timestamp(
+ seconds=seconds,
+ always_include_hours=self.always_include_hours,
+ decimal_marker=self.decimal_marker,
+ )
+
+
+class WriteVTT(SubtitlesWriter):
+ extension: str = "vtt"
+ always_include_hours: bool = False
+ decimal_marker: str = "."
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ print("WEBVTT\n", file=file)
+ for start, end, text in self.iterate_result(result, options, **kwargs):
+ print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
+
+
+class WriteSRT(SubtitlesWriter):
+ extension: str = "srt"
+ always_include_hours: bool = True
+ decimal_marker: str = ","
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ for i, (start, end, text) in enumerate(
+ self.iterate_result(result, options, **kwargs), start=1
+ ):
+ print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
+
+
+class WriteTSV(ResultWriter):
+ """
+ Write a transcript to a file in TSV (tab-separated values) format containing lines like:
+ <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
+
+ Using integer milliseconds as start and end times means there's no chance of interference from
+ an environment setting a language encoding that causes the decimal in a floating point number
+ to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
+ """
+
+ extension: str = "tsv"
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ print("start", "end", "text", sep="\t", file=file)
+ for segment in result["segments"]:
+ print(round(1000 * segment["start"]), file=file, end="\t")
+ print(round(1000 * segment["end"]), file=file, end="\t")
+ print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
+
+
+class WriteJSON(ResultWriter):
+ extension: str = "json"
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ json.dump(result, file, ensure_ascii=False)
+
+
+def get_writer(
+ output_format: str, output_dir: str
+) -> Callable[[dict, TextIO, dict], None]:
+ writers = {
+ "txt": WriteTXT,
+ "vtt": WriteVTT,
+ "srt": WriteSRT,
+ "tsv": WriteTSV,
+ "json": WriteJSON,
+ }
+
+ if output_format == "all":
+ all_writers = [writer(output_dir) for writer in writers.values()]
+
+ def write_all(
+ result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ for writer in all_writers:
+ writer(result, file, options, **kwargs)
+
+ return write_all
+
+ return writers[output_format](output_dir)
diff --git a/whisper/setup.py b/whisper/setup.py
index c400a547..086f6471 100644
--- a/whisper/setup.py
+++ b/whisper/setup.py
@@ -29,4 +29,9 @@ setup(
packages=find_namespace_packages(),
include_package_data=True,
python_requires=">=3.8",
+ entry_points={
+ "console_scripts": [
+ "mlx_whisper = mlx_whisper.cli:main",
+ ]
+ },
)
From 7be292c0c9bca9651fce1c93f4b20ee7fb7b1e82 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 16 Aug 2024 15:28:39 -0700
Subject: [PATCH 012/188] Handle longer prompt/generation (#931)
* rebase
* nits
* nit
* fix rotating cache with step prefill
* update version
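
For illustration only (not part of the diff): the new `RotatingKVCache` caps memory by overwriting the oldest entries once `max_size` is reached while pinning the first `keep` positions; `generate_step` constructs it with `keep=4` when a maximum KV size is given. A small sketch mirroring the new unit test:

```python
import mlx.core as mx

from mlx_lm.models.base import RotatingKVCache

b, h, d = 1, 2, 32
cache = RotatingKVCache(d, h, max_size=8, keep=2)

# Push more single-token updates than max_size: once full, new entries
# rotate over the oldest ones while the first `keep` positions stay put.
for _ in range(20):
    k = mx.random.uniform(shape=(b, h, 1, d))
    v = mx.random.uniform(shape=(b, h, 1, d))
    keys, values = cache.update_and_fetch(k, v)

print(keys.shape)    # (1, 2, 8, 32): the cache never grows past max_size
print(cache.offset)  # 20: total tokens seen so far
```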
---
llms/mlx_lm/generate.py | 8 +-
llms/mlx_lm/models/base.py | 112 ++++++++++++++++++++++++--
llms/mlx_lm/models/cohere.py | 2 +
llms/mlx_lm/models/dbrx.py | 2 +
llms/mlx_lm/models/deepseek_v2.py | 2 +
llms/mlx_lm/models/gemma.py | 2 +
llms/mlx_lm/models/gemma2.py | 2 +
llms/mlx_lm/models/gpt2.py | 2 +
llms/mlx_lm/models/gpt_bigcode.py | 2 +
llms/mlx_lm/models/gpt_neox.py | 2 +
llms/mlx_lm/models/internlm2.py | 2 +
llms/mlx_lm/models/llama.py | 2 +
llms/mlx_lm/models/minicpm.py | 2 +
llms/mlx_lm/models/mixtral.py | 2 +
llms/mlx_lm/models/olmo.py | 2 +
llms/mlx_lm/models/openelm.py | 2 +
llms/mlx_lm/models/phi.py | 2 +
llms/mlx_lm/models/phi3.py | 2 +
llms/mlx_lm/models/phi3small.py | 2 +
llms/mlx_lm/models/phixtral.py | 2 +
llms/mlx_lm/models/plamo.py | 2 +
llms/mlx_lm/models/qwen.py | 2 +
llms/mlx_lm/models/qwen2.py | 2 +
llms/mlx_lm/models/qwen2_moe.py | 2 +
llms/mlx_lm/models/recurrent_gemma.py | 8 ++
llms/mlx_lm/models/stablelm.py | 2 +
llms/mlx_lm/models/starcoder2.py | 2 +
llms/mlx_lm/models/su_rope.py | 2 +
llms/mlx_lm/models/switch_layers.py | 2 +
llms/mlx_lm/utils.py | 26 +++++-
llms/mlx_lm/version.py | 2 +-
llms/tests/test_models.py | 60 +++++++++++++-
32 files changed, 255 insertions(+), 13 deletions(-)
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index c003940b..6707d25c 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -76,7 +76,12 @@ def setup_arg_parser():
type=int,
default=None,
help="Set the MLX cache limit in GB",
- required=False,
+ )
+ parser.add_argument(
+ "--max-kv-size",
+ type=int,
+ default=1024,
+ help="Set the maximum key-value cache size",
)
return parser
@@ -154,6 +159,7 @@ def main():
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
+ max_kv_size=args.max_kv_size,
)
diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index 3fe276d2..3e84554c 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -1,6 +1,8 @@
+# Copyright © 2023-2024 Apple Inc.
+
import inspect
from dataclasses import dataclass
-from typing import List, Optional
+from typing import Any, List, Optional
import mlx.core as mx
import mlx.nn as nn
@@ -44,6 +46,100 @@ class KVCache:
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
+ def state(self):
+ return self.keys, self.values
+
+
+class RotatingKVCache:
+
+ def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
+ self.n_kv_heads = n_kv_heads
+ if isinstance(head_dim, int):
+ self.k_head_dim = self.v_head_dim = head_dim
+ elif isinstance(head_dim, tuple) and len(head_dim) == 2:
+ self.k_head_dim, self.v_head_dim = head_dim
+ else:
+ raise ValueError("head_dim must be an int or a tuple of two ints")
+ self.keep = keep
+ self.keys = None
+ self.values = None
+ self.offset = 0
+ self.max_size = max_size
+ self.step = step
+ self._idx = 0
+
+ def _trim(self, trim_size, v, append=None):
+ to_cat = []
+ if trim_size > 0:
+ to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
+ else:
+ to_cat = [v]
+ if append is not None:
+ to_cat.append(append)
+ return mx.concatenate(to_cat, axis=2)
+
+ def update_and_fetch(self, keys, values):
+ prev = self.offset
+ B, _, S = keys.shape[:3]
+
+ # Prefill mode
+ if S > 1:
+ if self.keys is None:
+ self.keys = keys
+ self.values = values
+ else:
+ # The largest size is self.max_size + S - 1 to ensure
+ # every token gets at least self.max_size context
+ trim_size = self.keys.shape[2] - self.max_size + 1
+ self.keys = self._trim(trim_size, self.keys, keys)
+ self.values = self._trim(trim_size, self.values, values)
+ self.offset += S
+ self._idx = self.keys.shape[2]
+ return self.keys, self.values
+
+ # Generation mode
+ # May not have hit the max size yet, so potentially
+ # keep growing the cache
+ if self.keys is None or (
+ prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
+ ):
+ new_size = min(self.step, self.max_size - prev)
+ k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
+ v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
+ new_k = mx.zeros(k_shape, keys.dtype)
+ new_v = mx.zeros(v_shape, values.dtype)
+ if self.keys is not None:
+ self.keys = mx.concatenate([self.keys, new_k], axis=2)
+ self.values = mx.concatenate([self.values, new_v], axis=2)
+ else:
+ self.keys, self.values = new_k, new_v
+ self._idx = prev
+
+ # Trim if needed
+ trim_size = self.keys.shape[2] - self.max_size
+ if trim_size > 0:
+ self.keys = self._trim(trim_size, self.keys)
+ self.values = self._trim(trim_size, self.values)
+ self._idx = self.max_size
+
+ # Rotate
+ if self._idx == self.max_size:
+ self._idx = self.keep
+
+ # Assign
+ self.keys[..., self._idx : self._idx + 1, :] = keys
+ self.values[..., self._idx : self._idx + 1, :] = values
+ self.offset += 1
+ self._idx += 1
+
+ # If the buffer is not full, slice off the end
+ if self.offset < self.max_size:
+ return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
+ return self.keys, self.values
+
+ def state(self):
+ return self.keys, self.values
+
@dataclass
class BaseModelArgs:
@@ -65,13 +161,17 @@ def create_additive_causal_mask(N: int, offset: int = 0):
return mask * -1e9
-def create_attention_mask(h: mx.array, cache: Optional[List[KVCache]] = None):
+def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
T = h.shape[1]
if T > 1:
- # Input consists of multiple tokens, create a causal mask so that prior
- # tokens do not give attention to later tokens. If a cache is in place
- # (because e.g. prompt reuse), offset the mask accordingly.
- offset = cache[0].offset if cache is not None and cache[0] is not None else 0
+ if cache is not None and cache[0] is not None:
+ c = cache[0]
+ if isinstance(c, RotatingKVCache):
+ offset = min(c.max_size - 1, c.offset)
+ else:
+ offset = c.offset
+ else:
+ offset = 0
mask = create_additive_causal_mask(T, offset)
mask = mask.astype(h.dtype)
else:
diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py
index 7dc2b9bf..cfcf2945 100644
--- a/llms/mlx_lm/models/cohere.py
+++ b/llms/mlx_lm/models/cohere.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Optional, Tuple
diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py
index 7a2a7a7d..f0214549 100644
--- a/llms/mlx_lm/models/dbrx.py
+++ b/llms/mlx_lm/models/dbrx.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Optional, Tuple
diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py
index bd743e53..f320b564 100644
--- a/llms/mlx_lm/models/deepseek_v2.py
+++ b/llms/mlx_lm/models/deepseek_v2.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py
index 323ebaa6..c6150284 100644
--- a/llms/mlx_lm/models/gemma.py
+++ b/llms/mlx_lm/models/gemma.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Optional, Tuple
diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py
index d4bd8a5d..1d410a15 100644
--- a/llms/mlx_lm/models/gemma2.py
+++ b/llms/mlx_lm/models/gemma2.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Optional, Tuple
diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py
index 81f71cac..8a770936 100644
--- a/llms/mlx_lm/models/gpt2.py
+++ b/llms/mlx_lm/models/gpt2.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py
index a5336203..652eb9e4 100644
--- a/llms/mlx_lm/models/gpt_bigcode.py
+++ b/llms/mlx_lm/models/gpt_bigcode.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py
index 1d2f74b7..c2aaa9ea 100644
--- a/llms/mlx_lm/models/gpt_neox.py
+++ b/llms/mlx_lm/models/gpt_neox.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py
index 2ee2af2d..bcc0cf0c 100644
--- a/llms/mlx_lm/models/internlm2.py
+++ b/llms/mlx_lm/models/internlm2.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 2f323245..192e591f 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py
index a3d01cbb..df0670be 100644
--- a/llms/mlx_lm/models/minicpm.py
+++ b/llms/mlx_lm/models/minicpm.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index c7d8c5c5..2db57752 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index 8a28ad74..59849c96 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from sys import exit
from typing import Optional, Tuple
diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py
index 3f0d2605..19d3c027 100644
--- a/llms/mlx_lm/models/openelm.py
+++ b/llms/mlx_lm/models/openelm.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py
index 520ac1ad..fd3fd709 100644
--- a/llms/mlx_lm/models/phi.py
+++ b/llms/mlx_lm/models/phi.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
import math
from dataclasses import dataclass
from typing import Tuple
diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index 2536aacb..f8facdb1 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py
index de075652..665dbc73 100644
--- a/llms/mlx_lm/models/phi3small.py
+++ b/llms/mlx_lm/models/phi3small.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
import math
from dataclasses import dataclass
from functools import partial
diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py
index f0aef0c9..bb67615d 100644
--- a/llms/mlx_lm/models/phixtral.py
+++ b/llms/mlx_lm/models/phixtral.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
import inspect
import math
from dataclasses import dataclass
diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py
index 47a9ea4f..5d2b7586 100644
--- a/llms/mlx_lm/models/plamo.py
+++ b/llms/mlx_lm/models/plamo.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union
diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py
index 67816599..6d2c7bbf 100644
--- a/llms/mlx_lm/models/qwen.py
+++ b/llms/mlx_lm/models/qwen.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Tuple
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index cb8268aa..b3ce02a3 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py
index 121ab813..ff7831f3 100644
--- a/llms/mlx_lm/models/qwen2_moe.py
+++ b/llms/mlx_lm/models/qwen2_moe.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py
index 428431e3..34750ace 100644
--- a/llms/mlx_lm/models/recurrent_gemma.py
+++ b/llms/mlx_lm/models/recurrent_gemma.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
import math
from dataclasses import dataclass
from typing import List, Literal, Optional
@@ -53,6 +55,9 @@ class RecurrentCache:
def update(self, conv_state, recurrent_state):
self._cache = (conv_state, recurrent_state)
+ def state(self):
+ return self._cache
+
class WindowKVCache:
@@ -80,6 +85,9 @@ class WindowKVCache:
self.values = _update(self.values, values)
return self.keys, self.values
+ def state(self):
+ return self.keys, self.values
+
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py
index 9b4d043c..b340de28 100644
--- a/llms/mlx_lm/models/stablelm.py
+++ b/llms/mlx_lm/models/stablelm.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
import math
from dataclasses import dataclass
from typing import Tuple
diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index a6eb5377..9cec0e39 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
from dataclasses import dataclass
from typing import Optional, Tuple
diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py
index cdf6ceaf..2ee20a63 100644
--- a/llms/mlx_lm/models/su_rope.py
+++ b/llms/mlx_lm/models/su_rope.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
import math
from typing import List, Union
diff --git a/llms/mlx_lm/models/switch_layers.py b/llms/mlx_lm/models/switch_layers.py
index 00aa65d8..4a157473 100644
--- a/llms/mlx_lm/models/switch_layers.py
+++ b/llms/mlx_lm/models/switch_layers.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
import math
import mlx.core as mx
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index e7a9dba8..44196766 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -19,7 +19,7 @@ from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer
# Local imports
-from .models.base import KVCache
+from .models.base import KVCache, RotatingKVCache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import apply_lora_layers
@@ -136,6 +136,8 @@ def generate_step(
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
logit_bias: Optional[Dict[int, float]] = None,
+ prefill_step_size: int = 512,
+ max_kv_size: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -156,6 +158,9 @@ def generate_step(
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
logit_bias (dictionary, optional): Additive logit bias.
+ prefill_step_size (int): Step size for processing the prompt.
+ max_kv_size (int, optional): Maximum size of the key-value cache. Old
+ entries (except the first 4 tokens) will be overwritten.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@@ -197,7 +202,13 @@ def generate_step(
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
- cache = [KVCache(model.head_dim, n) for n in kv_heads]
+ if max_kv_size is not None:
+ cache = [
+ RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
+ for n in kv_heads
+ ]
+ else:
+ cache = [KVCache(model.head_dim, n) for n in kv_heads]
repetition_context = prompt.tolist()
@@ -223,6 +234,11 @@ def generate_step(
repetition_context = repetition_context[-repetition_context_size:]
return y, logprobs.squeeze(0)
+ while y.size > prefill_step_size:
+ model(y[:prefill_step_size][None], cache=cache)
+ mx.eval([c.state for c in cache])
+ y = y[prefill_step_size:]
+
y, logprobs = _step(y)
mx.async_eval(y)
@@ -343,8 +359,10 @@ def generate(
return
prompt_tps = prompt_tokens.size / prompt_time
gen_tps = (token_count - 1) / gen_time
- print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
- print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+ print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
+ print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
+ peak_mem = mx.metal.get_peak_memory() / 2**30
+ print(f"Peak memory: {peak_mem:.3f} GB")
return detokenizer.text
diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py
index 40b73ede..f73aaa0a 100644
--- a/llms/mlx_lm/version.py
+++ b/llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
-__version__ = "0.16.0"
+__version__ = "0.17.0"
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index 19341981..fcf1dc33 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -4,7 +4,7 @@ import unittest
import mlx.core as mx
from mlx.utils import tree_map
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.base import KVCache, RotatingKVCache
class TestModels(unittest.TestCase):
@@ -29,6 +29,64 @@ class TestModels(unittest.TestCase):
self.assertTrue(mx.array_equal(v_up, expected))
self.assertEqual(cache.offset, cache.step + 1)
+ def test_rotating_kv_cache(self):
+ b, h, d = 1, 2, 32
+ cache = RotatingKVCache(d, h, max_size=8, step=4)
+
+ k = mx.random.uniform(shape=(b, h, 2, d))
+ v = mx.random.uniform(shape=(b, h, 2, d))
+
+ k_up, v_up = cache.update_and_fetch(k, v)
+ self.assertTrue(mx.array_equal(k_up, k))
+ self.assertTrue(mx.array_equal(v_up, v))
+ self.assertEqual(cache.offset, 2)
+
+ k = mx.random.uniform(shape=(b, h, 5, d))
+ v = mx.random.uniform(shape=(b, h, 5, d))
+ k_up, v_up = cache.update_and_fetch(k, v)
+ self.assertTrue(mx.array_equal(k_up[..., 2:, :], k))
+ self.assertTrue(mx.array_equal(v_up[..., 2:, :], v))
+
+ k = mx.random.uniform(shape=(b, h, 4, d))
+ v = mx.random.uniform(shape=(b, h, 4, d))
+ k_up, v_up = cache.update_and_fetch(k, v)
+ self.assertTrue(mx.array_equal(k_up[..., -4:, :], k))
+ self.assertTrue(mx.array_equal(v_up[..., -4:, :], v))
+
+ idx = 0
+ for _ in range(10):
+ k = mx.random.uniform(shape=(b, h, 1, d))
+ v = mx.random.uniform(shape=(b, h, 1, d))
+ k_up, v_up = cache.update_and_fetch(k, v)
+ self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k))
+ self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v))
+ idx += 1
+ idx %= 8
+
+ # Try with nonzero keep
+ cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2)
+
+ # Check a large update
+ k = mx.random.uniform(shape=(b, h, 20, d))
+ v = mx.random.uniform(shape=(b, h, 20, d))
+ k_up, v_up = cache.update_and_fetch(k, v)
+ self.assertTrue(mx.array_equal(k_up, k))
+ self.assertTrue(mx.array_equal(v_up, v))
+
+ # A bunch of small updates
+ self.assertEqual(cache.offset, 20)
+ idx = 2
+ for i in range(10):
+ k = mx.random.uniform(shape=(b, h, 1, d))
+ v = mx.random.uniform(shape=(b, h, 1, d))
+ k_up, v_up = cache.update_and_fetch(k, v)
+ self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k))
+ self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v))
+ self.assertEqual(cache.offset, 21 + i)
+ idx += 1
+ if idx >= 8:
+ idx = 2
+
def model_test_runner(self, model, model_type, vocab_size, num_layers):
self.assertEqual(len(model.layers), num_layers)
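The rotating cache and chunked prefill introduced above are easiest to see end to end from the Python API. Below is a minimal sketch, not part of the patch, assuming mlx-lm 0.17 or later; the model repo, prompt, and parameter values are illustrative, and the exact keyword plumbing from `generate` to `generate_step` is an assumption based on the diff.

```python
from mlx_lm import load, generate

# Illustrative model repo; any supported chat model works the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# max_kv_size bounds the cache: old entries (past the first kept tokens) are
# overwritten by the RotatingKVCache, keeping memory flat on long generations.
text = generate(
    model,
    tokenizer,
    prompt="Write a long story about a lighthouse keeper.",
    max_tokens=500,
    max_kv_size=512,
)
print(text)
```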
From 0164d2058bf43197d4fe04bf26afa22cda78b7f0 Mon Sep 17 00:00:00 2001
From: L
Date: Sat, 17 Aug 2024 23:18:09 +0900
Subject: [PATCH 013/188] feat: DeepSeek MoE v1 (#942)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* feat: deepseek v1
DeepSeek is still releasing models on the DeepSeek V1 architecture.
```sh
mlx_lm.convert --hf-path deepseek-ai/DeepSeek-Prover-V1.5-RL --mlx-path DeepSeek-Prover-V1.5-RL-8bit --q-bits 8 -q
mlx_lm.generate --model DeepSeek-Prover-V1.5-RL-8bit --ignore-chat-template --max-tokens 512 --prompt 'import Mathlib
import Aesop
set_option maxHeartbeats 0
open BigOperators Real Nat Topology Rat
/-- The second and fourth terms of a geometric sequence are $2$ and $6$. Which of the following is a possible first term?
Show that it is $\frac{2\sqrt{3}}{3}$.-/
theorem amc12b_2003_p6 (a r : ℝ) (u : ℕ → ℝ) (h₀ : ∀ k, u k = a * r ^ k) (h₁ : u 1 = 2)
(h₂ : u 3 = 6) : u 0 = 2 / Real.sqrt 3 ∨ u 0 = -(2 / Real.sqrt 3) := by'
```
* nits
* nits
* nits
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/models/deepseek.py | 266 +++++++++++++++++++++++++++++++++
llms/mlx_lm/tuner/utils.py | 1 +
2 files changed, 267 insertions(+)
create mode 100644 llms/mlx_lm/models/deepseek.py
diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py
new file mode 100644
index 00000000..dcfa331c
--- /dev/null
+++ b/llms/mlx_lm/models/deepseek.py
@@ -0,0 +1,266 @@
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs, KVCache, create_attention_mask
+from .switch_layers import SwitchGLU
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+ model_type: str = "deepseek"
+ vocab_size: int = 102400
+ hidden_size: int = 4096
+ intermediate_size: int = 11008
+ moe_intermediate_size: int = 1407
+ num_hidden_layers: int = 30
+ num_attention_heads: int = 32
+ num_key_value_heads: int = 32
+ n_shared_experts: Optional[int] = None
+ n_routed_experts: Optional[int] = None
+ num_experts_per_tok: Optional[int] = None
+ moe_layer_freq: int = 1
+ first_k_dense_replace: int = 0
+ max_position_embeddings: int = 2048
+ rms_norm_eps: float = 1e-6
+ rope_theta: float = 10000.0
+ rope_scaling: Optional[Dict] = None
+ attention_bias: bool = False
+
+
+class DeepseekAttention(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ self.num_kv_heads = config.num_key_value_heads
+ self.head_dim = config.hidden_size // config.num_attention_heads
+ self.scale = self.head_dim**-0.5
+
+ attention_bias = getattr(config, "attention_bias", False)
+
+ self.q_proj = nn.Linear(
+ self.hidden_size,
+ config.num_attention_heads * self.head_dim,
+ bias=attention_bias,
+ )
+ self.k_proj = nn.Linear(
+ self.hidden_size,
+ config.num_key_value_heads * self.head_dim,
+ bias=attention_bias,
+ )
+ self.v_proj = nn.Linear(
+ self.hidden_size,
+ config.num_key_value_heads * self.head_dim,
+ bias=attention_bias,
+ )
+ self.o_proj = nn.Linear(
+ self.hidden_size,
+ config.num_attention_heads * self.head_dim,
+ bias=attention_bias,
+ )
+
+ rope_scale = 1.0
+ if config.rope_scaling and config.rope_scaling["type"] == "linear":
+ assert isinstance(config.rope_scaling["factor"], float)
+ rope_scale = 1 / config.rope_scaling["factor"]
+ self.rope = nn.RoPE(
+ self.head_dim,
+ base=config.rope_theta,
+ scale=rope_scale,
+ )
+
+ def __call__(
+ self,
+ x: mx.array,
+ mask: Optional[mx.array] = None,
+ cache: Optional[KVCache] = None,
+ ) -> mx.array:
+ B, L, _ = x.shape
+
+ queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+ queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
+ 0, 2, 1, 3
+ )
+ keys = keys.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
+ values = values.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
+
+ if cache is not None:
+ queries = self.rope(queries, offset=cache.offset)
+ keys = self.rope(keys, offset=cache.offset)
+ keys, values = cache.update_and_fetch(keys, values)
+ else:
+ queries = self.rope(queries)
+ keys = self.rope(keys)
+
+ output = mx.fast.scaled_dot_product_attention(
+ queries, keys, values, scale=self.scale, mask=mask
+ )
+ output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+ return self.o_proj(output)
+
+
+class DeepseekMLP(nn.Module):
+ def __init__(
+ self,
+ config: ModelArgs,
+ hidden_size: int | None = None,
+ intermediate_size: int | None = None,
+ ):
+ super().__init__()
+ self.config = config
+ self.hidden_size = hidden_size or config.hidden_size
+ self.intermediate_size = intermediate_size or config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = nn.silu
+
+ def __call__(self, x: mx.array) -> mx.array:
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class MoEGate(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ self.config = config
+ self.top_k = config.num_experts_per_tok
+ self.n_routed_experts = config.n_routed_experts
+ self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
+
+ def __call__(self, x):
+ gates = x @ self.weight.T
+ scores = mx.softmax(gates, axis=-1, precise=True)
+ k = self.top_k
+ inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
+ scores = mx.take_along_axis(scores, inds, axis=-1)
+ return inds, scores
+
+
+class DeepseekMoE(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ self.config = config
+ self.switch_mlp = SwitchGLU(
+ config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
+ )
+
+ self.gate = MoEGate(config)
+ if config.n_shared_experts is not None:
+ intermediate_size = config.moe_intermediate_size * config.n_shared_experts
+ self.shared_experts = DeepseekMLP(
+ config=config, intermediate_size=intermediate_size
+ )
+
+ def __call__(self, x):
+ inds, scores = self.gate(x)
+ y = self.switch_mlp(x, inds)
+ y = (y * scores[..., None]).sum(axis=-2)
+ if self.config.n_shared_experts is not None:
+ y = y + self.shared_experts(x)
+
+ return y
+
+
+class DeepseekDecoderLayer(nn.Module):
+ def __init__(self, config: ModelArgs, layer_idx: int):
+ super().__init__()
+ self.self_attn = DeepseekAttention(config)
+ self.mlp = (
+ DeepseekMoE(config)
+ if (
+ config.n_routed_experts is not None
+ and layer_idx >= config.first_k_dense_replace
+ and layer_idx % config.moe_layer_freq == 0
+ )
+ else DeepseekMLP(config)
+ )
+ self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = nn.RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
+
+ def __call__(
+ self,
+ x: mx.array,
+ mask: Optional[mx.array] = None,
+ cache: Optional[KVCache] = None,
+ ) -> mx.array:
+ r = self.self_attn(self.input_layernorm(x), mask, cache)
+ h = x + r
+ r = self.mlp(self.post_attention_layernorm(h))
+ out = h + r
+ return out
+
+
+class DeepseekModel(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ self.config = config
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.layers = [
+ DeepseekDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
+ ]
+ self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def __call__(
+ self,
+ x: mx.array,
+ cache: Optional[KVCache] = None,
+ ) -> mx.array:
+ h = self.embed_tokens(x)
+ mask = create_attention_mask(h, cache)
+
+ if cache is None:
+ cache = [None] * len(self.layers)
+
+ for layer, c in zip(self.layers, cache):
+ h = layer(h, mask, c)
+
+ return self.norm(h)
+
+
+class Model(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ self.args = config
+ self.model_type = config.model_type
+ self.model = DeepseekModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ def __call__(
+ self,
+ inputs: mx.array,
+ cache: Optional[KVCache] = None,
+ ):
+ out = self.model(inputs, cache)
+ return self.lm_head(out)
+
+ def sanitize(self, weights):
+ for l in range(self.args.num_hidden_layers):
+ prefix = f"model.layers.{l}"
+ for m in ["gate_proj", "down_proj", "up_proj"]:
+ for k in ["weight", "scales", "biases"]:
+ if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
+ to_join = [
+ weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
+ for e in range(self.args.n_routed_experts)
+ ]
+ weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+ return weights
+
+ @property
+ def layers(self):
+ return self.model.layers
+
+ @property
+ def head_dim(self):
+ return self.args.hidden_size // self.args.num_attention_heads
+
+ @property
+ def n_kv_heads(self):
+ return self.args.num_key_value_heads
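The `MoEGate` above routes each token to its top-k experts with `mx.argpartition` before `SwitchGLU` evaluates only the selected experts. A standalone sketch of that routing step, with illustrative shapes and random weights:

```python
import mlx.core as mx

x = mx.random.uniform(shape=(2, 5, 8))    # (batch, sequence, hidden_size)
weight = mx.random.uniform(shape=(4, 8))  # (n_routed_experts, hidden_size)

gates = x @ weight.T                      # router logits per expert
scores = mx.softmax(gates, axis=-1, precise=True)

k = 2  # num_experts_per_tok
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
topk_scores = mx.take_along_axis(scores, inds, axis=-1)
print(inds.shape, topk_scores.shape)      # (2, 5, 2) (2, 5, 2)
```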
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index b7f1f9de..c6af9730 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -101,6 +101,7 @@ def linear_to_lora_layers(
"starcoder2",
"cohere",
"minicpm",
+ "deepseek",
]:
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
if model.model_type == "mixtral":
From 58591a1b4175265e7b316de5b4d0365be658c6b6 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 22 Aug 2024 10:41:21 -0700
Subject: [PATCH 014/188] fine tune deepseek (#932)
---
llms/mlx_lm/tuner/utils.py | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index c6af9730..9f18c2c0 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -128,6 +128,16 @@ def linear_to_lora_layers(
keys = set(["norm_attn_norm.attn.Wqkv", "ffn.router.layer"])
elif model.model_type == "internlm2":
keys = set(["attention.wqkv", "attention.wo"])
+ elif model.model_type == "deepseek_v2":
+ keys = set(
+ [
+ "self_attn.q_proj",
+ "self_attn.q_a_proj",
+ "self_attn.q_b_proj",
+ "self_attn.kv_a_proj_with_mqa",
+ "self_attn.kv_b_proj",
+ ]
+ )
else:
raise ValueError(f"Lora does not support {model.model_type}")
From 6731254e761f69d3c0925fd682a4d988fbf3fe7a Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 23 Aug 2024 13:18:51 -0700
Subject: [PATCH 015/188] Use fast rope (#945)
* use fast rope
* fix llama
* use fast rope for llama3.1
* requires unreleased mlx
* fix su
* fix deepseek v2
* only one of base or freqs
* nit
* fix
* hard code freqs
---
llms/mlx_lm/models/deepseek_v2.py | 91 +++++++----------------
llms/mlx_lm/models/llama.py | 43 ++++-------
llms/mlx_lm/models/phi3.py | 4 +-
llms/mlx_lm/models/su_rope.py | 51 ++++---------
llms/mlx_lm/requirements.txt | 2 +-
llms/mlx_lm/version.py | 2 +-
stable_diffusion/stable_diffusion/unet.py | 9 +--
7 files changed, 65 insertions(+), 137 deletions(-)
diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py
index f320b564..602a9710 100644
--- a/llms/mlx_lm/models/deepseek_v2.py
+++ b/llms/mlx_lm/models/deepseek_v2.py
@@ -68,13 +68,12 @@ def yarn_get_mscale(scale=1, mscale=1):
return 0.1 * mscale * math.log(scale) + 1.0
-def yarn_linear_ramp_mask(min, max, dim):
- if min == max:
- max += 0.001 # Prevent singularity
+def yarn_linear_ramp_mask(min_val, max_val, dim):
+ if min_val == max_val:
+ max_val += 0.001 # Prevent singularity
- linear_func = (mx.arange(dim, dtype=mx.float32) - min) / (max - min)
- ramp_func = mx.clip(linear_func, 0, 1)
- return ramp_func
+ linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
+ return mx.clip(linear_func, 0, 1)
class DeepseekV2YarnRotaryEmbedding(nn.Module):
@@ -91,72 +90,36 @@ class DeepseekV2YarnRotaryEmbedding(nn.Module):
mscale_all_dim=0,
):
super().__init__()
- self.dim = dim
- self.max_position_embeddings = max_position_embeddings
- self.base = base
- self.scaling_factor = scaling_factor
- self.original_max_position_embeddings = original_max_position_embeddings
- self.beta_fast = beta_fast
- self.beta_slow = beta_slow
- self.mscale = mscale
- self.mscale_all_dim = mscale_all_dim
-
- self.max_seq_len_cached = None
- self._cos_cached = None
- self._sin_cached = None
- self._inv_freq = None
- self.set_cos_sin_cache(max_position_embeddings)
-
- def set_cos_sin_cache(self, seq_len):
- self.max_seq_len_cached = seq_len
- dim = self.dim
- freq_extra = 1.0 / (self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))
- freq_inter = 1.0 / (
- self.scaling_factor
- * self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
+ self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
+ scaling_factor, mscale_all_dim
+ )
+ freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
+ freq_inter = scaling_factor * base ** (
+ mx.arange(0, dim, 2, dtype=mx.float32) / dim
)
-
low, high = yarn_find_correction_range(
- self.beta_fast,
- self.beta_slow,
+ beta_fast,
+ beta_slow,
dim,
- self.base,
- self.original_max_position_embeddings,
+ base,
+ original_max_position_embeddings,
)
- inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
- inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
- self._inv_freq = inv_freq
-
- t = mx.arange(seq_len, dtype=mx.float32)
- freqs = mx.outer(t, inv_freq)
-
- mscale = yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale(
- self.scaling_factor, self.mscale_all_dim
+ freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
+ self._freqs = (freq_inter * freq_extra) / (
+ freq_inter * freq_mask + freq_extra * (1 - freq_mask)
)
- self._cos_cached = mx.cos(freqs) * mscale
- self._sin_cached = mx.sin(freqs) * mscale
-
- def apply_rotary_pos_emb(self, x, cos, sin):
- x1 = x[..., ::2]
- x2 = x[..., 1::2]
- rx1 = x1 * cos - x2 * sin
- rx2 = x1 * sin + x2 * cos
- return mx.concatenate([rx1, rx2], axis=-1)
-
def __call__(self, x, offset=0):
- seq_len = offset + x.shape[2]
- if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
- self.set_cos_sin_cache(seq_len=seq_len)
-
- if self._cos_cached.dtype != x.dtype:
- self._cos_cached = self._cos_cached.astype(x.dtype)
- self._sin_cached = self._sin_cached.astype(x.dtype)
-
- return self.apply_rotary_pos_emb(
+ if self.mscale != 1.0:
+ x = self.mscale * x
+ return mx.fast.rope(
x,
- self._cos_cached[offset:seq_len],
- self._sin_cached[offset:seq_len],
+ x.shape[-1],
+ traditional=True,
+ base=None,
+ scale=1.0,
+ offset=offset,
+ freqs=self._freqs,
)
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 192e591f..c4a947a5 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -65,19 +65,16 @@ class DynamicNTKScalingRoPE(nn.Module):
self.dims = dims
self.max_position_embeddings = max_position_embeddings
self.traditional = traditional
- self.original_base = base
self.scale = scale
self.rope_type = rope_type
self.rope_scaling = rope_scaling
- self.base = self.compute_base_freq()
+ self.base = base
+ self.compute_freqs()
- def compute_base_freq(self):
- if self.rope_type == "llama3":
- return self.compute_llama3_base_freq()
- return self.original_base
-
- # source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318
- def compute_llama3_base_freq(self):
+ def compute_freqs(self):
+ if self.rope_type != "llama3":
+ self._freqs = None
+ return
factor = self.rope_scaling["factor"]
low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
@@ -89,19 +86,17 @@ class DynamicNTKScalingRoPE(nn.Module):
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
- freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims)
+ freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims)
wavelens = 2 * mx.pi * freqs
- new_base_freqs = []
- smooths = (wavelens - high_freq_wavelen) / (
- low_freq_wavelen - high_freq_wavelen
+ freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
+ is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
+ smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
+ high_freq_factor - low_freq_factor
)
- new_base_freqs = freqs * (1 - smooths) * factor + smooths
- new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs)
- new_base_freqs = mx.where(
- wavelens > low_freq_wavelen, freqs * factor, new_base_freqs
- )
- return new_base_freqs.mean().item()
+ smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
+ self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
+ self.base = None
def extra_repr(self):
return (
@@ -111,20 +106,14 @@ class DynamicNTKScalingRoPE(nn.Module):
)
def __call__(self, x, offset: int = 0):
- seq_len = x.shape[1] + offset
- base = self.base
- if self.max_position_embeddings and seq_len > self.max_position_embeddings:
- base *= (
- (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
- ) ** (self.dims / (self.dims - 2))
-
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
- base=base,
+ base=self.base,
scale=self.scale,
offset=offset,
+ freqs=self._freqs,
)
diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index f8facdb1..112ade7d 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -59,19 +59,17 @@ class Attention(nn.Module):
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
- rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
self.rope = SuScaledRotaryEmbedding(
head_dim,
- traditional=False,
base=args.rope_theta,
- scale=rope_scale,
max_position_embeddings=args.max_position_embeddings,
original_max_position_embeddings=args.original_max_position_embeddings,
short_factor=args.rope_scaling["short_factor"],
long_factor=args.rope_scaling["long_factor"],
)
else:
+ rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] == "linear":
assert isinstance(args.rope_scaling["factor"], float)
rope_scale = 1 / args.rope_scaling["factor"]
diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py
index 2ee20a63..f96b9957 100644
--- a/llms/mlx_lm/models/su_rope.py
+++ b/llms/mlx_lm/models/su_rope.py
@@ -4,15 +4,14 @@ import math
from typing import List, Union
import mlx.core as mx
+import mlx.nn as nn
-class SuScaledRotaryEmbedding:
+class SuScaledRotaryEmbedding(nn.Module):
def __init__(
self,
dims: int,
- traditional: bool = False,
base: float = 10000.0,
- scale: float = 1.0,
max_position_embeddings: int = 131072,
original_max_position_embeddings: int = 4096,
short_factor: Union[List[float], float] = 1.0,
@@ -23,10 +22,7 @@ class SuScaledRotaryEmbedding:
Args:
dims (int): The feature dimensions to be rotated.
- traditional (bool, optional): Unused. Default: ``False``.
base (int, optional): Base for the exponential scaling.
- scale (float, optional): The scale used to scale the positions.
- Default: ``1.0``.
max_position_embeddings (int, optional): The maximum sequence
length that this model was trained with. This is used to determine
the size of the original RoPE embeddings when using long scaling.
@@ -42,40 +38,23 @@ class SuScaledRotaryEmbedding:
factors for sequences of length greater than
``original_max_position_embeddings``. Default: ``1.0``.
"""
- self.inv_freq_short = 1.0 / (
- mx.array(short_factor, dtype=mx.float32)
- * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
- )
- self.inv_freq_long = 1.0 / (
- scale
- * mx.array(long_factor, dtype=mx.float32)
- * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
- )
+ super().__init__()
+ freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+ self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs
self.original_max_position_embeddings = original_max_position_embeddings
- self.scaling_factor = math.sqrt(
+ self.scale = math.sqrt(
1
+ math.log(max_position_embeddings / original_max_position_embeddings)
/ math.log(original_max_position_embeddings)
)
- def _get_cos_sin(self, offset, L):
- position_ids = mx.arange(offset, offset + L, dtype=mx.float32)
- inv_freq = (
- self.inv_freq_long
- if (offset + L) > self.original_max_position_embeddings
- else self.inv_freq_short
- )
- freqs = position_ids[:, None] * inv_freq[None, :]
- emb = mx.concatenate([freqs, freqs], axis=-1)
- cos = mx.cos(emb) * self.scaling_factor
- sin = mx.sin(emb) * self.scaling_factor
- return cos, sin
-
def __call__(self, x, offset: int = 0):
- def _rotate_half(_x):
- midpoint = _x.shape[-1] // 2
- x1, x2 = _x[..., :midpoint], _x[..., midpoint:]
- return mx.concatenate([-x2, x1], axis=-1)
-
- cos, sin = self._get_cos_sin(offset, x.shape[2])
- return (x * cos) + (_rotate_half(x) * sin)
+ return mx.fast.rope(
+ self.scale * x,
+ x.shape[-1],
+ traditional=False,
+ base=None,
+ scale=1.0,
+ offset=offset,
+ freqs=self._freqs,
+ )
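After this patch, the scaled RoPE variants all reduce to a single `mx.fast.rope` call with a precomputed frequency vector. A minimal sketch of that call in isolation, assuming mlx 0.17.0 or later per the requirements bump below; the shapes and base are illustrative:

```python
import mlx.core as mx

B, H, L, D = 1, 4, 16, 64
x = mx.random.normal(shape=(B, H, L, D))

# Per-dimension frequencies, analogous to what the classes above precompute.
freqs = 10000.0 ** (mx.arange(0, D, 2, dtype=mx.float32) / D)

y = mx.fast.rope(
    x,
    D,                  # number of feature dimensions to rotate
    traditional=False,
    base=None,          # only one of base or freqs may be given
    scale=1.0,
    offset=0,
    freqs=freqs,
)
print(y.shape)  # (1, 4, 16, 64)
```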
diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt
index 4875f931..814c03cc 100644
--- a/llms/mlx_lm/requirements.txt
+++ b/llms/mlx_lm/requirements.txt
@@ -1,4 +1,4 @@
-mlx>=0.14.1
+mlx>=0.17.0
numpy
transformers[sentencepiece]>=4.39.3
protobuf
diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py
index f73aaa0a..41237905 100644
--- a/llms/mlx_lm/version.py
+++ b/llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
-__version__ = "0.17.0"
+__version__ = "0.17.1"
diff --git a/stable_diffusion/stable_diffusion/unet.py b/stable_diffusion/stable_diffusion/unet.py
index ec2915e5..cfad7fcc 100644
--- a/stable_diffusion/stable_diffusion/unet.py
+++ b/stable_diffusion/stable_diffusion/unet.py
@@ -110,7 +110,7 @@ class Transformer2D(nn.Module):
# Perform the input norm and projection
B, H, W, C = x.shape
- x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
+ x = self.norm(x).reshape(B, -1, C)
x = self.proj_in(x)
# Apply the transformer
@@ -156,12 +156,12 @@ class ResnetBlock2D(nn.Module):
if temb is not None:
temb = self.time_emb_proj(nn.silu(temb))
- y = self.norm1(x.astype(mx.float32)).astype(dtype)
+ y = self.norm1(x)
y = nn.silu(y)
y = self.conv1(y)
if temb is not None:
y = y + temb[:, None, None, :]
- y = self.norm2(y.astype(mx.float32)).astype(dtype)
+ y = self.norm2(y)
y = nn.silu(y)
y = self.conv2(y)
@@ -453,8 +453,7 @@ class UNetModel(nn.Module):
)
# Postprocess the output
- dtype = x.dtype
- x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
+ x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
From b5e18ef1e399f81baa6e92b1e25c2d8fe0394f42 Mon Sep 17 00:00:00 2001
From: Prince Canuma
Date: Sat, 24 Aug 2024 15:52:33 +0200
Subject: [PATCH 016/188] Add Phi-3.5-MoE (#946)
* add phimoe
* add phimoe to tunner
* add switch_mlp
* fix SuScaled args
* nits
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/models/phimoe.py | 218 ++++++++++++++++++++++++++++++++++
llms/mlx_lm/models/su_rope.py | 6 +-
llms/mlx_lm/tuner/utils.py | 3 +-
3 files changed, 225 insertions(+), 2 deletions(-)
create mode 100644 llms/mlx_lm/models/phimoe.py
diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py
new file mode 100644
index 00000000..db6bd4b5
--- /dev/null
+++ b/llms/mlx_lm/models/phimoe.py
@@ -0,0 +1,218 @@
+# Copyright © 2024 Apple Inc.
+import math
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs, create_attention_mask
+from .su_rope import SuScaledRotaryEmbedding
+from .switch_layers import SwitchGLU
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+ model_type: str = "phimoe"
+ vocab_size: int = 32064
+ hidden_size: int = 4096
+ intermediate_size: int = 6400
+ num_hidden_layers: int = 32
+ num_attention_heads: int = 32
+ num_key_value_heads: int = 8
+ max_position_embeddings: int = 131072
+ original_max_position_embeddings: int = 4096
+ rms_norm_eps: float = 1e-6
+ rope_scaling: Dict[str, Union[float, List[float]]] = None
+ num_local_experts: int = 16
+ num_experts_per_tok: int = 2
+ rope_theta: float = 10000.0
+
+
+class Attention(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+
+ dim = args.hidden_size
+ self.n_heads = n_heads = args.num_attention_heads
+ self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+ head_dim = args.hidden_size // n_heads
+ self.scale = head_dim**-0.5
+
+ self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
+ self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
+ self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
+ self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
+
+ self.rope = SuScaledRotaryEmbedding(
+ head_dim,
+ base=args.rope_theta,
+ max_position_embeddings=args.max_position_embeddings,
+ original_max_position_embeddings=args.original_max_position_embeddings,
+ short_factor=args.rope_scaling["short_factor"],
+ long_factor=args.rope_scaling["long_factor"],
+ short_mscale=args.rope_scaling["short_mscale"],
+ long_mscale=args.rope_scaling["long_mscale"],
+ )
+
+ def __call__(
+ self,
+ x: mx.array,
+ mask: Optional[mx.array] = None,
+ cache=None,
+ ) -> mx.array:
+ B, L, D = x.shape
+
+ queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+ # Prepare the queries, keys and values for the attention computation
+ queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+ keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+ values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+ if cache is not None:
+ queries = self.rope(queries, offset=cache.offset)
+ keys = self.rope(keys, offset=cache.offset)
+ keys, values = cache.update_and_fetch(keys, values)
+ else:
+ queries = self.rope(queries)
+ keys = self.rope(keys)
+
+ output = mx.fast.scaled_dot_product_attention(
+ queries, keys, values, scale=self.scale, mask=mask
+ )
+ output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+ return self.o_proj(output)
+
+
+class PhiMoESparseMoeBlock(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.hidden_dim = args.hidden_size
+ self.ffn_dim = args.intermediate_size
+ self.num_experts = args.num_local_experts
+ self.top_k = args.num_experts_per_tok
+
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+ self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)
+
+ def __call__(self, x: mx.array) -> mx.array:
+ gates = self.gate(x)
+
+ k = self.top_k
+ inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
+ scores = mx.take_along_axis(gates, inds, axis=-1)
+ scores = mx.softmax(scores, axis=-1, precise=True)
+
+ y = self.switch_mlp(x, inds)
+ y = (y * scores[..., None]).sum(axis=-2)
+
+ return y
+
+
+class PhiMoEDecoderLayer(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.hidden_size = args.hidden_size
+
+ self.self_attn = Attention(args)
+ self.block_sparse_moe = PhiMoESparseMoeBlock(args)
+ self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
+ self.post_attention_layernorm = nn.LayerNorm(
+ args.hidden_size, eps=args.rms_norm_eps
+ )
+
+ def __call__(
+ self,
+ x: mx.array,
+ mask: Optional[mx.array] = None,
+ cache=None,
+ ) -> mx.array:
+ residual = x
+ hidden_states = self.input_layernorm(x)
+ hidden_states = self.self_attn(hidden_states, mask=mask, cache=cache)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.block_sparse_moe(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class PhiMoEModel(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.args = args
+ self.vocab_size = args.vocab_size
+
+ self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+ self.layers = [PhiMoEDecoderLayer(args) for _ in range(args.num_hidden_layers)]
+ self.norm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+ def __call__(
+ self,
+ inputs: mx.array,
+ cache=None,
+ ) -> mx.array:
+ h = self.embed_tokens(inputs)
+
+ mask = create_attention_mask(h, cache)
+
+ if cache is None:
+ cache = [None] * len(self.layers)
+
+ for layer, c in zip(self.layers, cache):
+ h = layer(h, mask, c)
+
+ return self.norm(h)
+
+
+class Model(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.args = args
+ self.model = PhiMoEModel(args)
+ self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)
+
+ def __call__(
+ self,
+ inputs: mx.array,
+ cache=None,
+ ):
+ out = self.model(inputs, cache)
+ return self.lm_head(out)
+
+ def sanitize(self, weights):
+ if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
+ return weights
+ for l in range(self.args.num_hidden_layers):
+ prefix = f"model.layers.{l}"
+ for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
+ for k in ["weight", "scales", "biases"]:
+ if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
+ to_join = [
+ weights.pop(
+ f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
+ )
+ for e in range(self.args.num_local_experts)
+ ]
+ weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
+ mx.stack(to_join)
+ )
+
+ return weights
+
+ @property
+ def layers(self):
+ return self.model.layers
+
+ @property
+ def head_dim(self):
+ return self.args.hidden_size // self.args.num_attention_heads
+
+ @property
+ def n_kv_heads(self):
+ return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py
index f96b9957..9c414afd 100644
--- a/llms/mlx_lm/models/su_rope.py
+++ b/llms/mlx_lm/models/su_rope.py
@@ -16,6 +16,8 @@ class SuScaledRotaryEmbedding(nn.Module):
original_max_position_embeddings: int = 4096,
short_factor: Union[List[float], float] = 1.0,
long_factor: Union[List[float], float] = 1.0,
+ short_mscale: float = None,
+ long_mscale: float = None,
):
"""
Phi3Su Scaled Rotary Embedding layer for Phi-3 models.
@@ -37,12 +39,14 @@ class SuScaledRotaryEmbedding(nn.Module):
long_factor (float or list[float], optional): List of scaling
factors for sequences of length greater than
``original_max_position_embeddings``. Default: ``1.0``.
+ short_mscale (float, optional): Scale the input prior to embedding.
+ long_mscale (float, optional): Scale the input prior to embedding.
"""
super().__init__()
freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs
self.original_max_position_embeddings = original_max_position_embeddings
- self.scale = math.sqrt(
+ self.scale = long_mscale or math.sqrt(
1
+ math.log(max_position_embeddings / original_max_position_embeddings)
/ math.log(original_max_position_embeddings)
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 9f18c2c0..1a54a925 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -96,6 +96,7 @@ def linear_to_lora_layers(
"stablelm",
"qwen2",
"qwen2_moe",
+ "phimoe",
"gemma",
"gemma2",
"starcoder2",
@@ -104,7 +105,7 @@ def linear_to_lora_layers(
"deepseek",
]:
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
- if model.model_type == "mixtral":
+ if model.model_type in ["mixtral", "phimoe"]:
keys.add("block_sparse_moe.gate")
if model.model_type == "qwen2_moe":
keys.add("mlp.gate")
From bf21789b17377bed99dbe846030f56ac810e0339 Mon Sep 17 00:00:00 2001
From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com>
Date: Mon, 26 Aug 2024 20:24:23 +0530
Subject: [PATCH 017/188] chore: update black pre-commit hooks to latest
versions (#955)
---
.pre-commit-config.yaml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1a648a69..012a96ad 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black-pre-commit-mirror
- rev: 24.3.0
+ rev: 24.8.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
From 7f8c961287442ec1b78bedbba8e34251ec6defa2 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Wed, 28 Aug 2024 14:47:33 -0700
Subject: [PATCH 018/188] Fix setattr for the TokenizerWrapper (#961)
---
llms/mlx_lm/tokenizer_utils.py | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index 6caad629..04bbbcc5 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -252,9 +252,19 @@ class TokenizerWrapper:
def __getattr__(self, attr):
if attr == "detokenizer":
return self._detokenizer
+ elif attr.startswith("_"):
+ return self.__getattribute__(attr)
else:
return getattr(self._tokenizer, attr)
+ def __setattr__(self, attr, value):
+ if attr == "detokenizer":
+ raise AttributeError("Cannot set the detokenizer.")
+ elif attr.startswith("_"):
+ super().__setattr__(attr, value)
+ else:
+ setattr(self._tokenizer, attr, value)
+
def _match(a, b):
if type(a) != type(b):
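With this change, attribute writes on `TokenizerWrapper` forward to the wrapped Hugging Face tokenizer, while private state and the detokenizer stay on the wrapper. A short usage sketch; the model repo and template string are illustrative:

```python
from mlx_lm import load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Public attribute writes now pass through to the wrapped Hugging Face tokenizer.
tokenizer.chat_template = "{% for m in messages %}{{ m['content'] }}{% endfor %}"
print(tokenizer.chat_template == tokenizer._tokenizer.chat_template)  # True

# The detokenizer itself stays read-only on the wrapper.
try:
    tokenizer.detokenizer = None
except AttributeError as e:
    print(e)  # Cannot set the detokenizer.
```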
From 1003a8b2dd59a22255a6a6c9a20f9d41f9812fb5 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Wed, 28 Aug 2024 22:11:45 -0700
Subject: [PATCH 019/188] Add the ability to load the KV cache from a file
(#956)
---
llms/mlx_lm/cache_prompt.py | 149 ++++++++++++++++++++++++++++++++++++
llms/mlx_lm/generate.py | 69 +++++++++++++++--
llms/mlx_lm/models/base.py | 2 +
llms/mlx_lm/utils.py | 51 ++++++++----
llms/setup.py | 1 +
5 files changed, 250 insertions(+), 22 deletions(-)
create mode 100644 llms/mlx_lm/cache_prompt.py
diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
new file mode 100644
index 00000000..ad045f1a
--- /dev/null
+++ b/llms/mlx_lm/cache_prompt.py
@@ -0,0 +1,149 @@
+# Copyright © 2024 Apple Inc.
+
+import argparse
+import json
+import sys
+import time
+
+import mlx.core as mx
+
+from .utils import load, make_kv_caches
+
+
+def setup_arg_parser():
+ """Set up and return the argument parser."""
+ parser = argparse.ArgumentParser(
+ description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="mlx_model",
+ help="The path to the local model directory or Hugging Face repo.",
+ )
+ parser.add_argument(
+ "--adapter-path",
+ type=str,
+ help="Optional path for the trained adapter weights and config.",
+ )
+ parser.add_argument(
+ "--trust-remote-code",
+ action="store_true",
+ help="Enable trusting remote code for tokenizer",
+ )
+ parser.add_argument(
+ "--eos-token",
+ type=str,
+ default=None,
+ help="End of sequence token for tokenizer",
+ )
+ parser.add_argument(
+ "--ignore-chat-template",
+ action="store_true",
+ help="Use the raw prompt without the tokenizer's chat template.",
+ )
+ parser.add_argument(
+ "--use-default-chat-template",
+ action="store_true",
+ help="Use the default chat template",
+ )
+ parser.add_argument(
+ "--cache-limit-gb",
+ type=int,
+ default=None,
+ help="Set the MLX cache limit in GB",
+ )
+ parser.add_argument(
+ "--max-kv-size",
+ type=int,
+ default=1024,
+ help="Set the maximum key-value cache size",
+ )
+ parser.add_argument(
+ "--kv-cache-file", help="The file to save the KV caches in", required=True
+ )
+ parser.add_argument(
+ "--prompt",
+ required=True,
+ help="Message to be processed by the model ('-' reads from stdin)",
+ )
+ return parser
+
+
+def main():
+ parser = setup_arg_parser()
+ args = parser.parse_args()
+
+ if args.cache_limit_gb is not None:
+ mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
+
+ # Building tokenizer_config
+ tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
+ if args.eos_token is not None:
+ tokenizer_config["eos_token"] = args.eos_token
+
+ model, tokenizer = load(
+ args.model,
+ adapter_path=args.adapter_path,
+ tokenizer_config=tokenizer_config,
+ )
+
+ args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt
+
+ if args.use_default_chat_template:
+ if tokenizer.chat_template is None:
+ tokenizer.chat_template = tokenizer.default_chat_template
+
+ if not args.ignore_chat_template and (
+ hasattr(tokenizer, "apply_chat_template")
+ and tokenizer.chat_template is not None
+ ):
+ messages = [{"role": "user", "content": args.prompt}]
+ prompt = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+
+ # Treat the prompt as a prefix assuming that the suffix will be
+ # provided at generation time.
+ test_prompt = tokenizer.apply_chat_template(
+ [{"role": "user", "content": ""}],
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
+ prompt = prompt[:-n]
+ else:
+ prompt = args.prompt
+
+ cache = make_kv_caches(model, args.max_kv_size)
+ y = mx.array(tokenizer.encode(prompt))
+
+ # Process the prompt
+ processed = 0
+ step_size = 512
+ start = time.time()
+ max_msg_len = 0
+ while y.size > 0:
+ model(y[:step_size][None], cache=cache)
+ mx.eval([c.state for c in cache])
+ processed += min(y.size, step_size)
+ y = y[step_size:]
+ current = time.time()
+ speed = processed / (current - start)
+ msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
+ max_msg_len = max(max_msg_len, len(msg))
+ print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
+ print()
+ print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
+
+ print("Saving...")
+ cache_dict = {}
+ for i, c in enumerate(cache):
+ cache_dict[f"{i}_keys"] = c.state[0]
+ cache_dict[f"{i}_values"] = c.state[1]
+ metadata = {}
+ metadata["model"] = args.model
+ metadata["chat_template"] = tokenizer.chat_template
+ metadata["tokenizer_config"] = json.dumps(tokenizer_config)
+ metadata["max_kv_size"] = str(args.max_kv_size)
+ mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 6707d25c..4aa4001a 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -1,17 +1,18 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
+import json
import mlx.core as mx
from .utils import generate, load
-DEFAULT_MODEL_PATH = "mlx_model"
DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.6
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
+DEFAULT_MAX_KV_SIZE = 1024
def setup_arg_parser():
@@ -20,7 +21,6 @@ def setup_arg_parser():
parser.add_argument(
"--model",
type=str,
- default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
@@ -80,9 +80,14 @@ def setup_arg_parser():
parser.add_argument(
"--max-kv-size",
type=int,
- default=1024,
help="Set the maximum key-value cache size",
)
+ parser.add_argument(
+ "--kv-cache-file",
+ type=str,
+ default=None,
+ help="A file containing saved KV caches to avoid recomputing them",
+ )
return parser
@@ -113,6 +118,24 @@ def colorprint_by_t0(s, t0):
colorprint(color, s)
+def load_kv_cache_from_file(kv_cache_file):
+ if kv_cache_file is None:
+ return None, None
+
+ kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True)
+ cache_per_layer = {}
+ for k, x in kv_cache.items():
+ layer, kv_type = k.split("_")
+ if layer not in cache_per_layer:
+ cache_per_layer[layer] = {}
+ cache_per_layer[layer][kv_type] = x
+
+ cache_history = [None] * len(cache_per_layer)
+ for layer, c in cache_per_layer.items():
+ cache_history[int(layer)] = (c["keys"], c["values"])
+ return cache_history, metadata
+
+
def main():
parser = setup_arg_parser()
args = parser.parse_args()
@@ -122,13 +145,25 @@ def main():
if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
+ # Load the kv cache and metadata if a kv cache file is provided
+ cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file)
+
# Building tokenizer_config
- tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
+ tokenizer_config = (
+ {} if cache_history is None else json.loads(metadata["tokenizer_config"])
+ )
+ if args.trust_remote_code:
+ tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
+ # If no model path is provided then use the one in the kv cache history
+ model_path = args.model
+ if cache_history is not None and model_path is None:
+ model_path = metadata["model"]
+
model, tokenizer = load(
- args.model,
+ model_path,
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
@@ -136,6 +171,8 @@ def main():
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
+ elif tokenizer.chat_template is None:
+ tokenizer.chat_template = metadata["chat_template"]
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
@@ -145,11 +182,30 @@ def main():
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
+
+ # Treat the prompt as a suffix assuming that the prefix is in the
+ # stored kv cache.
+ if cache_history is not None:
+ test_prompt = tokenizer.apply_chat_template(
+ [{"role": "user", "content": ""}],
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ prompt = prompt[test_prompt.index("<query>") :]
else:
prompt = args.prompt
formatter = colorprint_by_t0 if args.colorize else None
+ # Determine the max kv size from the kv cache or passed arguments
+ max_kv_size = args.max_kv_size
+ if max_kv_size is None:
+ max_kv_size = (
+ int(metadata["max_kv_size"])
+ if cache_history is not None
+ else DEFAULT_MAX_KV_SIZE
+ )
+
generate(
model,
tokenizer,
@@ -159,7 +215,8 @@ def main():
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
- max_kv_size=args.max_kv_size,
+ max_kv_size=max_kv_size,
+ cache_history=cache_history,
)
diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index 3e84554c..dc19dd05 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -46,6 +46,7 @@ class KVCache:
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
+ @property
def state(self):
return self.keys, self.values
@@ -137,6 +138,7 @@ class RotatingKVCache:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
+ @property
def state(self):
return self.keys, self.values
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 44196766..71476df3 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -9,7 +9,7 @@ import shutil
import time
from pathlib import Path
from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
import mlx.core as mx
import mlx.nn as nn
@@ -126,6 +126,26 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
return logits
+def make_kv_caches(
+ model: nn.Module, max_kv_size: Optional[int] = None
+) -> List[Union[KVCache, RotatingKVCache]]:
+ if hasattr(model, "make_cache"):
+ return model.make_cache()
+
+ kv_heads = (
+ [model.n_kv_heads] * len(model.layers)
+ if isinstance(model.n_kv_heads, int)
+ else model.n_kv_heads
+ )
+ if max_kv_size is not None:
+ return [
+ RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
+ for n in kv_heads
+ ]
+ else:
+ return [KVCache(model.head_dim, n) for n in kv_heads]
+
+
def generate_step(
prompt: mx.array,
model: nn.Module,
@@ -138,6 +158,7 @@ def generate_step(
logit_bias: Optional[Dict[int, float]] = None,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None,
+ cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -194,21 +215,19 @@ def generate_step(
)
y = prompt
- if hasattr(model, "make_cache"):
- cache = model.make_cache()
- else:
- kv_heads = (
- [model.n_kv_heads] * len(model.layers)
- if isinstance(model.n_kv_heads, int)
- else model.n_kv_heads
- )
- if max_kv_size is not None:
- cache = [
- RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
- for n in kv_heads
- ]
- else:
- cache = [KVCache(model.head_dim, n) for n in kv_heads]
+
+ # Create the KV cache for generation
+ cache = make_kv_caches(model, max_kv_size)
+
+ if cache_history is not None:
+ if len(cache_history) != len(cache):
+ raise ValueError("Wrong number of layers in the cache history")
+
+ # Set the history in the cache objects and evaluate them to prepare for
+ # generation.
+ for c, h in zip(cache, cache_history):
+ c.update_and_fetch(h[0], h[1])
+ mx.eval([c.state for c in cache])
repetition_context = prompt.tolist()
diff --git a/llms/setup.py b/llms/setup.py
index 88deed17..ac294ae1 100644
--- a/llms/setup.py
+++ b/llms/setup.py
@@ -31,6 +31,7 @@ setup(
},
entry_points={
"console_scripts": [
+ "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
"mlx_lm.convert = mlx_lm.convert:main",
"mlx_lm.fuse = mlx_lm.fuse:main",
"mlx_lm.generate = mlx_lm.generate:main",
From b1186e2a81c678087bcaab43202d040b126523c3 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 29 Aug 2024 15:05:17 -0700
Subject: [PATCH 020/188] Docs on prompt scaling (#963)
* docs on prompt scaling
* remove unused var
* nits
---
llms/README.md | 42 ++++++++++++++++++++++++++++++++++---
llms/mlx_lm/cache_prompt.py | 6 +++++-
llms/mlx_lm/generate.py | 11 ++++------
llms/mlx_lm/version.py | 2 +-
4 files changed, 49 insertions(+), 12 deletions(-)
diff --git a/llms/README.md b/llms/README.md
index 497c0277..79f26d41 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -38,7 +38,9 @@ To see a description of all the arguments you can do:
>>> help(generate)
```
-Check out the [generation example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py) to see how to use the API in more detail.
+Check out the [generation
+example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py)
+to see how to use the API in more detail.
The `mlx-lm` package also comes with functionality to quantize and optionally
upload models to the Hugging Face Hub.
@@ -122,10 +124,44 @@ mlx_lm.convert \
--upload-repo mlx-community/my-4bit-mistral
```
+### Long Prompts and Generations
+
+MLX LM has some tools to scale efficiently to long prompts and generations:
+
+- A rotating fixed-size key-value cache.
+- Prompt caching.
+
+To use the rotating key-value cache, pass the argument `--max-kv-size n`, where
+`n` can be any integer. Smaller values like `512` will use very little RAM but
+result in worse quality. Larger values like `4096` or higher will use more RAM
+but have better quality.
+
+Caching prompts can substantially speed up reusing the same long context with
+different queries. To cache a prompt, use `mlx_lm.cache_prompt`. For example:
+
+```bash
+cat prompt.txt | mlx_lm.cache_prompt \
+ --model mistralai/Mistral-7B-Instruct-v0.3 \
+ --prompt - \
+ --kv-cache-file mistral_prompt.safetensors
+```
+
+Then use the cached prompt with `mlx_lm.generate`:
+
+```bash
+mlx_lm.generate \
+ --kv-cache-file mistral_prompt.safetensors \
+ --prompt "\nSummarize the above text."
+```
+
+The cached prompt is treated as a prefix to the supplied prompt. Also notice
+that when using a cached prompt, the model to use is read from the cache and
+need not be supplied explicitly.
+
### Supported Models
-The example supports Hugging Face format Mistral, Llama, and Phi-2 style
-models. If the model you want to run is not supported, file an
+MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
+run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.
diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
index ad045f1a..fe088118 100644
--- a/llms/mlx_lm/cache_prompt.py
+++ b/llms/mlx_lm/cache_prompt.py
@@ -56,7 +56,7 @@ def setup_arg_parser():
parser.add_argument(
"--max-kv-size",
type=int,
- default=1024,
+ default=None,
help="Set the maximum key-value cache size",
)
parser.add_argument(
@@ -147,3 +147,7 @@ def main():
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
metadata["max_kv_size"] = str(args.max_kv_size)
mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 4aa4001a..54f6f4d2 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -12,7 +12,6 @@ DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.6
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
-DEFAULT_MAX_KV_SIZE = 1024
def setup_arg_parser():
@@ -81,6 +80,7 @@ def setup_arg_parser():
"--max-kv-size",
type=int,
help="Set the maximum key-value cache size",
+ default=None,
)
parser.add_argument(
"--kv-cache-file",
@@ -199,12 +199,9 @@ def main():
# Determine the max kv size from the kv cache or passed arguments
max_kv_size = args.max_kv_size
- if max_kv_size is None:
- max_kv_size = (
- int(metadata["max_kv_size"])
- if cache_history is not None
- else DEFAULT_MAX_KV_SIZE
- )
+ if cache_history is not None:
+ max_kv_size = metadata["max_kv_size"]
+ max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
generate(
model,
diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py
index 41237905..87e86846 100644
--- a/llms/mlx_lm/version.py
+++ b/llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
-__version__ = "0.17.1"
+__version__ = "0.18.0"
From fc93c557238e9441835afe2748fd170016cb068b Mon Sep 17 00:00:00 2001
From: L
Date: Thu, 29 Aug 2024 21:08:57 -0700
Subject: [PATCH 021/188] feat(mlx_lm): Nemotron (#949)
* feat: Nemotron
https://huggingface.co/nvidia/Minitron-4B-Base
This is basically Llama with partial RoPE and LayerNorm instead of
RMSNorm. Also they add 1 to the LayerNorm weight for some reason.
* fixup! feat: Nemotron
* nits
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/models/nemotron.py | 227 +++++++++++++++++++++++++++++++++
llms/mlx_lm/tuner/utils.py | 1 +
2 files changed, 228 insertions(+)
create mode 100644 llms/mlx_lm/models/nemotron.py
diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py
new file mode 100644
index 00000000..ef55d1d7
--- /dev/null
+++ b/llms/mlx_lm/models/nemotron.py
@@ -0,0 +1,227 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from functools import partial
+from typing import Dict, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs, KVCache, create_attention_mask
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+ model_type: str
+ hidden_size: int
+ hidden_act: str
+ num_hidden_layers: int
+ intermediate_size: int
+ num_attention_heads: int
+ norm_eps: float
+ vocab_size: int
+ num_key_value_heads: int
+ head_dim: Optional[int] = None
+ max_position_embeddings: Optional[int] = None
+ attention_bias: bool = False
+ mlp_bias: bool = False
+ partial_rotary_factor: float = 0.5
+ rope_theta: float = 10000.0
+ rope_traditional: bool = False
+ rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+ tie_word_embeddings: bool = False
+
+ def __post_init__(self):
+ if self.rope_scaling:
+ if not "factor" in self.rope_scaling:
+ raise ValueError(f"rope_scaling must contain 'factor'")
+ rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
+ "rope_type"
+ )
+ if rope_type is None:
+ raise ValueError(
+ f"rope_scaling must contain either 'type' or 'rope_type'"
+ )
+ if rope_type not in ["linear"]:
+ raise ValueError("rope_scaling 'type' currently only supports 'linear'")
+
+
+@partial(mx.compile, shapeless=True)
+def relu_squared(x):
+ return nn.relu(x).square()
+
+
+class NemotronLayerNorm1P(nn.LayerNorm):
+ def __call__(self, x):
+ weight = self.weight + 1 if "weight" in self else None
+ bias = self.bias if "bias" in self else None
+ return mx.fast.layer_norm(x, weight, bias, self.eps)
+
+
+class Attention(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+
+ dim = args.hidden_size
+ self.n_heads = n_heads = args.num_attention_heads
+ self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+ self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
+ self.partial_rotary_factor = args.partial_rotary_factor
+
+ self.scale = head_dim**-0.5
+ if hasattr(args, "attention_bias"):
+ attention_bias = args.attention_bias
+ else:
+ attention_bias = False
+
+ self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
+ self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+ self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+ self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
+
+ rope_scale = 1.0
+ if args.rope_scaling and args.rope_scaling["type"] == "linear":
+ assert isinstance(args.rope_scaling["factor"], float)
+ rope_scale = 1 / args.rope_scaling["factor"]
+ self.rope = nn.RoPE(
+ int(self.partial_rotary_factor * self.head_dim),
+ base=args.rope_theta,
+ scale=rope_scale,
+ )
+
+ def __call__(
+ self,
+ x: mx.array,
+ mask: Optional[mx.array] = None,
+ cache: Optional[KVCache] = None,
+ ) -> mx.array:
+ B, L, _ = x.shape
+
+ queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+ # Prepare the queries, keys and values for the attention computation
+ queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+ keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+ values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+ if cache is not None:
+ queries = self.rope(queries, offset=cache.offset)
+ keys = self.rope(keys, offset=cache.offset)
+ keys, values = cache.update_and_fetch(keys, values)
+ else:
+ queries = self.rope(queries)
+ keys = self.rope(keys)
+
+ output = mx.fast.scaled_dot_product_attention(
+ queries, keys, values, scale=self.scale, mask=mask
+ )
+ output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+ return self.o_proj(output)
+
+
+class MLP(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+
+ dim = args.hidden_size
+ hidden_dim = args.intermediate_size
+ mlp_bias = args.mlp_bias
+
+ self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
+ self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
+
+ def __call__(self, x) -> mx.array:
+ return self.down_proj(relu_squared(self.up_proj(x)))
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.num_attention_heads = args.num_attention_heads
+ self.hidden_size = args.hidden_size
+ self.self_attn = Attention(args)
+ self.mlp = MLP(args)
+ self.input_layernorm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
+ self.post_attention_layernorm = NemotronLayerNorm1P(
+ args.hidden_size, eps=args.norm_eps
+ )
+
+ def __call__(
+ self,
+ x: mx.array,
+ mask: Optional[mx.array] = None,
+ cache: Optional[KVCache] = None,
+ ) -> mx.array:
+ r = self.self_attn(self.input_layernorm(x), mask, cache)
+ h = x + r
+ r = self.mlp(self.post_attention_layernorm(h))
+ out = h + r
+ return out
+
+
+class NemotronModel(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.args = args
+ self.vocab_size = args.vocab_size
+ self.num_hidden_layers = args.num_hidden_layers
+ assert self.vocab_size > 0
+ self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+ self.layers = [
+ TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
+ ]
+ self.norm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
+
+ def __call__(
+ self,
+ inputs: mx.array,
+ cache=None,
+ ):
+ h = self.embed_tokens(inputs)
+
+ mask = create_attention_mask(h, cache)
+
+ if cache is None:
+ cache = [None] * len(self.layers)
+
+ for layer, c in zip(self.layers, cache):
+ h = layer(h, mask, cache=c)
+
+ return self.norm(h)
+
+
+class Model(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.args = args
+ self.model_type = args.model_type
+ self.model = NemotronModel(args)
+ if not args.tie_word_embeddings:
+ self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+ def __call__(
+ self,
+ inputs: mx.array,
+ cache=None,
+ ):
+ out = self.model(inputs, cache)
+ if self.args.tie_word_embeddings:
+ out = self.model.embed_tokens.as_linear(out)
+ else:
+ out = self.lm_head(out)
+ return out
+
+ @property
+ def layers(self):
+ return self.model.layers
+
+ @property
+ def head_dim(self):
+ return (
+ self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
+ )
+
+ @property
+ def n_kv_heads(self):
+ return self.args.num_key_value_heads
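The two Nemotron-specific details above are the squared-ReLU MLP activation and the "1p" layer norm, which stores its gain as an offset from one. A minimal illustrative sketch (toy shapes, not taken from the patch):

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((2, 4, 8))

# Squared-ReLU activation used by the MLP: relu(x) ** 2.
y = nn.relu(x).square()

# "1p" layer norm: the stored weight is an offset from 1, so the effective
# scale passed to the fused kernel is (weight + 1).
ln = nn.LayerNorm(8, eps=1e-5)
out = mx.fast.layer_norm(x, ln.weight + 1, ln.bias, 1e-5)

print(y.shape, out.shape)  # (2, 4, 8) (2, 4, 8)
```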
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 1a54a925..71fbfaab 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -93,6 +93,7 @@ def linear_to_lora_layers(
"llama",
"phi",
"mixtral",
+ "nemotron",
"stablelm",
"qwen2",
"qwen2_moe",
From 3c6e8b11af9dda55ee8d38ee9f612a65118f7793 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 30 Aug 2024 05:56:27 -0700
Subject: [PATCH 022/188] fix (#965)
---
llms/mlx_lm/generate.py | 2 +-
llms/mlx_lm/version.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 54f6f4d2..f37037b6 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -171,7 +171,7 @@ def main():
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
- elif tokenizer.chat_template is None:
+ elif cache_history is not None:
tokenizer.chat_template = metadata["chat_template"]
if not args.ignore_chat_template and (
diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py
index 87e86846..a2eb9a25 100644
--- a/llms/mlx_lm/version.py
+++ b/llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
-__version__ = "0.18.0"
+__version__ = "0.18.1"
From bf921afcbef89463c2b8a78c4008d993500fadb4 Mon Sep 17 00:00:00 2001
From: James Zhao
Date: Tue, 3 Sep 2024 23:16:21 +0300
Subject: [PATCH 023/188] Make sure to import the correct "version" module when
installing mlx_whisper and mlx_lm from local source code. (#969)
* Make sure to import the correct "version" module when installing the
mlx_whisper package from local source code.
* Make sure to import the correct "version" module when installing the mlx_lm package from local source code
* fix
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/__init__.py | 2 +-
llms/mlx_lm/{version.py => _version.py} | 0
llms/setup.py | 2 +-
whisper/mlx_whisper/__init__.py | 2 +-
whisper/mlx_whisper/{version.py => _version.py} | 0
whisper/setup.py | 2 +-
6 files changed, 4 insertions(+), 4 deletions(-)
rename llms/mlx_lm/{version.py => _version.py} (100%)
rename whisper/mlx_whisper/{version.py => _version.py} (100%)
diff --git a/llms/mlx_lm/__init__.py b/llms/mlx_lm/__init__.py
index e971c467..502c78e5 100644
--- a/llms/mlx_lm/__init__.py
+++ b/llms/mlx_lm/__init__.py
@@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc.
+from ._version import __version__
from .utils import convert, generate, load, stream_generate
-from .version import __version__
diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/_version.py
similarity index 100%
rename from llms/mlx_lm/version.py
rename to llms/mlx_lm/_version.py
diff --git a/llms/setup.py b/llms/setup.py
index ac294ae1..e2cfe0cd 100644
--- a/llms/setup.py
+++ b/llms/setup.py
@@ -10,7 +10,7 @@ with open(package_dir / "requirements.txt") as fid:
requirements = [l.strip() for l in fid.readlines()]
sys.path.append(str(package_dir))
-from version import __version__
+from _version import __version__
setup(
name="mlx-lm",
diff --git a/whisper/mlx_whisper/__init__.py b/whisper/mlx_whisper/__init__.py
index e6de0858..14c5197f 100644
--- a/whisper/mlx_whisper/__init__.py
+++ b/whisper/mlx_whisper/__init__.py
@@ -1,5 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from . import audio, decoding, load_models
+from ._version import __version__
from .transcribe import transcribe
-from .version import __version__
diff --git a/whisper/mlx_whisper/version.py b/whisper/mlx_whisper/_version.py
similarity index 100%
rename from whisper/mlx_whisper/version.py
rename to whisper/mlx_whisper/_version.py
diff --git a/whisper/setup.py b/whisper/setup.py
index 086f6471..0cabd64b 100644
--- a/whisper/setup.py
+++ b/whisper/setup.py
@@ -12,7 +12,7 @@ with open(package_dir / "requirements.txt") as fid:
sys.path.append(str(package_dir))
-from version import __version__
+from _version import __version__
setup(
name="mlx-whisper",
From 83a209e200eef505a688997f64afdff2c6e0d363 Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Tue, 3 Sep 2024 16:29:10 -0400
Subject: [PATCH 024/188] Add prompt piping (#962)
* Initial commit of --prompt-only and prompt from STDIN feature
* Switch to using --verbose instead of --prompt-only
* Fix capitalization typo
* Fix reference to changed option name
* Update exception text
---
llms/mlx_lm/generate.py | 30 ++++++++++++++++++++++++++----
1 file changed, 26 insertions(+), 4 deletions(-)
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index f37037b6..537bd853 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -2,6 +2,7 @@
import argparse
import json
+import sys
import mlx.core as mx
@@ -14,6 +15,10 @@ DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
+def str2bool(string):
+ return string.lower() not in ["false", "f"]
+
+
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(description="LLM inference script")
@@ -39,7 +44,9 @@ def setup_arg_parser():
help="End of sequence token for tokenizer",
)
parser.add_argument(
- "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
+ "--prompt",
+ default=DEFAULT_PROMPT,
+ help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
@@ -65,6 +72,12 @@ def setup_arg_parser():
action="store_true",
help="Use the default chat template",
)
+ parser.add_argument(
+ "--verbose",
+ type=str2bool,
+ default=True,
+ help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
+ )
parser.add_argument(
"--colorize",
action="store_true",
@@ -178,7 +191,12 @@ def main():
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
- messages = [{"role": "user", "content": args.prompt}]
+ messages = [
+ {
+ "role": "user",
+ "content": sys.stdin.read() if args.prompt == "-" else args.prompt,
+ }
+ ]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
@@ -195,6 +213,8 @@ def main():
else:
prompt = args.prompt
+ if args.colorize and not args.verbose:
+ raise ValueError("Cannot use --colorize with --verbose=False")
formatter = colorprint_by_t0 if args.colorize else None
# Determine the max kv size from the kv cache or passed arguments
@@ -203,18 +223,20 @@ def main():
max_kv_size = metadata["max_kv_size"]
max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
- generate(
+ response = generate(
model,
tokenizer,
prompt,
args.max_tokens,
- verbose=True,
+ verbose=args.verbose,
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
max_kv_size=max_kv_size,
cache_history=cache_history,
)
+ if not args.verbose:
+ print(response)
if __name__ == "__main__":
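A hypothetical invocation of the new options, assuming the package is installed so that `python -m mlx_lm.generate` resolves; the prompt text is made up. Passing `-` makes the CLI read the prompt from stdin, and `--verbose False` prints only the response:

```python
import subprocess

prompt = "Write a haiku about the ocean."

# "-" selects stdin for the prompt; --verbose False relies on the str2bool
# parser added above and suppresses the extra logging.
result = subprocess.run(
    ["python", "-m", "mlx_lm.generate", "--prompt", "-", "--verbose", "False"],
    input=prompt,
    text=True,
    capture_output=True,
)
print(result.stdout)
```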
From bd29aec299c8fa59c161a9c1207bfc59db31d845 Mon Sep 17 00:00:00 2001
From: madroid
Date: Wed, 4 Sep 2024 21:19:32 +0800
Subject: [PATCH 025/188] Support HuggingFace model tree (#957)
* Hub: Update quantization configuration fields
* Hub: add base_model metadata
* Hub: add quantization_config for model tree Quantized type
* Hub: update quantization_config value
* Hub: remove config print
---
llms/mlx_lm/utils.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 71476df3..eee28c9c 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -560,6 +560,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
card = ModelCard.load(hf_path)
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
+ card.data.base_model = hf_path
card.text = dedent(
f"""
# {upload_repo}
@@ -666,6 +667,8 @@ def quantize_model(
quantized_config = copy.deepcopy(config)
nn.quantize(model, q_group_size, q_bits)
quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+ # support hf model tree #957
+ quantized_config["quantization_config"] = quantized_config["quantization"]
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
From 324184d670ec11916a5e92314171d497b312eefe Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Fri, 6 Sep 2024 20:19:27 -0700
Subject: [PATCH 026/188] Fix the cache_prompt (#979)
---
llms/mlx_lm/cache_prompt.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
index fe088118..9829efb4 100644
--- a/llms/mlx_lm/cache_prompt.py
+++ b/llms/mlx_lm/cache_prompt.py
@@ -139,8 +139,8 @@ def main():
print("Saving...")
cache_dict = {}
for i, c in enumerate(cache):
- cache_dict[f"{i}_keys"] = c.state[0]
- cache_dict[f"{i}_values"] = c.state[1]
+ cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
+ cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
metadata = {}
metadata["model"] = args.model
metadata["chat_template"] = tokenizer.chat_template
From c3e3411756e098a3f5f29d988e9221034f1af47c Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sat, 7 Sep 2024 06:06:15 -0700
Subject: [PATCH 027/188] Update LLM generation docs to use chat template
(#973)
* fix docs
* add template to model cards as well
* revert
* version
---
llms/README.md | 14 +++++++++++++-
llms/mlx_lm/_version.py | 2 +-
llms/mlx_lm/utils.py | 11 ++++++++++-
3 files changed, 24 insertions(+), 3 deletions(-)
diff --git a/llms/README.md b/llms/README.md
index 79f26d41..b8e1914d 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -29,7 +29,14 @@ from mlx_lm import load, generate
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
-response = generate(model, tokenizer, prompt="hello", verbose=True)
+prompt = "Write a story about Einstein"
+
+messages = [{"role": "user", "content": prompt}]
+prompt = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+)
+
+response = generate(model, tokenizer, prompt=prompt, verbose=True)
```
To see a description of all the arguments you can do:
@@ -79,6 +86,11 @@ model, tokenizer = load(repo)
prompt = "Write a story about Einstein"
+messages = [{"role": "user", "content": prompt}]
+prompt = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+)
+
for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(t, end="", flush=True)
print()
diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py
index a2eb9a25..8110c823 100644
--- a/llms/mlx_lm/_version.py
+++ b/llms/mlx_lm/_version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
-__version__ = "0.18.1"
+__version__ = "0.18.2"
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index eee28c9c..ad9b3221 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -577,7 +577,16 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
from mlx_lm import load, generate
model, tokenizer = load("{upload_repo}")
- response = generate(model, tokenizer, prompt="hello", verbose=True)
+
+ prompt="hello"
+
+ if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
+ messages = [{"role": "user", "content": prompt}]
+ prompt = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+
+ response = generate(model, tokenizer, prompt=prompt, verbose=True)
```
"""
)
From 6c2369e4b97f49fb5906ec46033497b39931b25d Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sat, 7 Sep 2024 14:46:57 -0700
Subject: [PATCH 028/188] Fix bug in upload + docs nit (#981)
* fix bug in upload + docs nit
* nit
---
llms/mlx_lm/LORA.md | 30 +++++++-----------------------
llms/mlx_lm/utils.py | 2 +-
2 files changed, 8 insertions(+), 24 deletions(-)
diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 2e739d0f..2d9a2553 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -166,44 +166,28 @@ Currently, `*.jsonl` files support three data formats: `chat`,
`chat`:
```jsonl
-{
- "messages": [
- {
- "role": "system",
- "content": "You are a helpful assistant."
- },
- {
- "role": "user",
- "content": "Hello."
- },
- {
- "role": "assistant",
- "content": "How can I assistant you today."
- }
- ]
-}
+{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assistant you today."}]}
```
`completions`:
```jsonl
-{
- "prompt": "What is the capital of France?",
- "completion": "Paris."
-}
+{"prompt": "What is the capital of France?", "completion": "Paris."}
```
`text`:
```jsonl
-{
- "text": "This is an example for the model."
-}
+{"text": "This is an example for the model."}
```
Note, the format is automatically determined by the dataset. Note also, keys in
each line not expected by the loader will be ignored.
+> [!NOTE]
+> Each example in the datasets must be on a single line. Do not put more than
+> one example per line and do not split an example across multiple lines.
+
### Hugging Face Datasets
To use Hugging Face datasets, first install the `datasets` package:
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index ad9b3221..b4a2ea51 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -581,7 +581,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
prompt="hello"
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
- messages = [{"role": "user", "content": prompt}]
+ messages = [{{"role": "user", "content": prompt}}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
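The doubled braces are needed because the model-card snippet is built inside an f-string, where literal `{` and `}` must be escaped or Python treats them as format fields. A minimal sketch with a hypothetical repo name:

```python
upload_repo = "mlx-community/example-model"  # hypothetical repo name

# {upload_repo} is interpolated; {{...}} renders as literal braces in the card.
card_text = f"""
model, tokenizer = load("{upload_repo}")
messages = [{{"role": "user", "content": "hello"}}]
"""
print(card_text)
```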
From f530f56df2738a54982c4541189a8c8d7cd94c44 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 17 Sep 2024 16:22:48 -0700
Subject: [PATCH 029/188] don't use internal exception (#990)
---
llms/mlx_lm/utils.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index b4a2ea51..5621609d 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -14,7 +14,6 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
-from huggingface_hub.utils._errors import RepositoryNotFoundError
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer
@@ -91,7 +90,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
],
)
)
- except RepositoryNotFoundError:
+ except:
raise ModelNotFoundError(
f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
"Please make sure you specified the local path or Hugging Face"
From 796d5e40e4cce0e0d49d3b3b3c00957b31702fe0 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Fri, 20 Sep 2024 13:33:45 -0700
Subject: [PATCH 030/188] Fix export to gguf (#993)
---
llms/mlx_lm/gguf.py | 27 ++++++++++++++-------------
1 file changed, 14 insertions(+), 13 deletions(-)
diff --git a/llms/mlx_lm/gguf.py b/llms/mlx_lm/gguf.py
index 5d524580..241ac35a 100644
--- a/llms/mlx_lm/gguf.py
+++ b/llms/mlx_lm/gguf.py
@@ -67,7 +67,7 @@ class HfVocab:
def get_token_type(
self, token_id: int, token_text: bytes, special_ids: Set[int]
) -> TokenType:
- if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text.encode("utf-8")):
+ if re.fullmatch(r"<0x[0-9A-Fa-f]{2}>", token_text):
return TokenType.BYTE
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
@@ -77,9 +77,7 @@ class HfVocab:
def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
for text in self.added_tokens_list:
if text in self.specials:
- toktype = self.get_token_type(
- self.specials[text], b"", self.special_ids
- )
+ toktype = self.get_token_type(self.specials[text], "", self.special_ids)
score = self.get_token_score(self.specials[text])
else:
toktype = TokenType.USER_DEFINED
@@ -243,15 +241,18 @@ def prepare_metadata(config, vocab):
metadata["tokenizer.ggml.tokens"] = tokens
metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32)
metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32)
- metadata["tokenizer.ggml.bos_token_id"] = mx.array(
- vocab.tokenizer.bos_token_id, dtype=mx.uint32
- )
- metadata["tokenizer.ggml.eos_token_id"] = mx.array(
- vocab.tokenizer.eos_token_id, dtype=mx.uint32
- )
- metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
- vocab.tokenizer.unk_token_id, dtype=mx.uint32
- )
+ if vocab.tokenizer.bos_token_id is not None:
+ metadata["tokenizer.ggml.bos_token_id"] = mx.array(
+ vocab.tokenizer.bos_token_id, dtype=mx.uint32
+ )
+ if vocab.tokenizer.eos_token_id is not None:
+ metadata["tokenizer.ggml.eos_token_id"] = mx.array(
+ vocab.tokenizer.eos_token_id, dtype=mx.uint32
+ )
+ if vocab.tokenizer.unk_token_id is not None:
+ metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
+ vocab.tokenizer.unk_token_id, dtype=mx.uint32
+ )
metadata = {k: v for k, v in metadata.items() if v is not None}
return metadata
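The guards matter because not every tokenizer defines all three special tokens; a missing one yields `None`, which cannot be turned into an `mx.array`. A small sketch with a made-up tokenizer:

```python
import mlx.core as mx

class DummyTokenizer:
    bos_token_id = 1
    eos_token_id = 2
    unk_token_id = None  # e.g. many BPE tokenizers define no <unk> token

metadata = {}
tok = DummyTokenizer()
for name, tid in [
    ("bos_token_id", tok.bos_token_id),
    ("eos_token_id", tok.eos_token_id),
    ("unknown_token_id", tok.unk_token_id),
]:
    if tid is not None:
        metadata[f"tokenizer.ggml.{name}"] = mx.array(tid, dtype=mx.uint32)

print(sorted(metadata))  # the unknown_token_id entry is simply omitted
```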
From 9bb2dd62f350d9f0f1a8003e354431b5172831e5 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 23 Sep 2024 11:39:25 -0700
Subject: [PATCH 031/188] Encodec (#991)
* initial encodec
* works
* nits
* use fast group norm
* fix for rnn layer
* fix mlx version
* use custom LSTM kernel
* audio encodec
* fix example, support batched inference
* nits
---
README.md | 1 +
encodec/README.md | 83 ++++
encodec/benchmarks/bench_mx.py | 30 ++
encodec/benchmarks/bench_pt.py | 34 ++
encodec/convert.py | 213 +++++++++++
encodec/encodec.py | 671 +++++++++++++++++++++++++++++++++
encodec/example.py | 37 ++
encodec/requirements.txt | 3 +
encodec/test.py | 66 ++++
encodec/utils.py | 129 +++++++
10 files changed, 1267 insertions(+)
create mode 100644 encodec/README.md
create mode 100644 encodec/benchmarks/bench_mx.py
create mode 100644 encodec/benchmarks/bench_pt.py
create mode 100644 encodec/convert.py
create mode 100644 encodec/encodec.py
create mode 100644 encodec/example.py
create mode 100644 encodec/requirements.txt
create mode 100644 encodec/test.py
create mode 100644 encodec/utils.py
diff --git a/README.md b/README.md
index 00e57803..bd180975 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@ Some more useful examples are listed below.
### Audio Models
- Speech recognition with [OpenAI's Whisper](whisper).
+- Audio compression and generation with [Meta's EnCodec](encodec).
### Multimodal models
diff --git a/encodec/README.md b/encodec/README.md
new file mode 100644
index 00000000..3ab2793c
--- /dev/null
+++ b/encodec/README.md
@@ -0,0 +1,83 @@
+# EnCodec
+
+An example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and
+generate audio.
+
+### Setup
+
+Install the requirements:
+
+```
+pip install -r requirements.txt
+```
+
+Optionally install FFmpeg and SciPy for loading and saving audio files,
+respectively.
+
+Install [FFmpeg](https://ffmpeg.org/):
+
+```
+# on macOS using Homebrew (https://brew.sh/)
+brew install ffmpeg
+```
+
+Install SciPy:
+
+```
+pip install scipy
+```
+
+### Example
+
+An example using the model:
+
+```python
+import mlx.core as mx
+from utils import load, load_audio, save_audio
+
+# Load the 48 kHz model and preprocessor.
+model, processor = load("mlx-community/encodec-48khz-float32")
+
+# Load an audio file
+audio = load_audio("path/to/aduio", model.sampling_rate, model.channels)
+
+# Preprocess the audio (this can also be a list of arrays for batched
+# processing).
+feats, mask = processor(audio)
+
+# Encode at the given bandwidth. A lower bandwidth results in more
+# compression but lower reconstruction quality.
+@mx.compile
+def encode(feats, mask):
+ return model.encode(feats, mask, bandwidth=3)
+
+# Decode to reconstruct the audio
+@mx.compile
+def decode(codes, scales, mask):
+ return model.decode(codes, scales, mask)
+
+
+codes, scales = encode(feats, mask)
+reconstructed = decode(codes, scales, mask)
+
+# Trim any padding:
+reconstructed = reconstructed[0, : len(audio)]
+
+# Save the audio as a wave file
+save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
+```
+
+The 24 kHz, 32 kHz, and 48 kHz MLX-formatted models are available in the
+[Hugging Face MLX Community](https://huggingface.co/collections/mlx-community/encodec-66e62334038300b07a43b164)
+in several data types.
+
+### Optional
+
+To convert models, use the `convert.py` script. To see the options, run:
+
+```bash
+python convert.py -h
+```
+
+[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2210.13438) and
+ [code](https://github.com/facebookresearch/encodec) for more details.
diff --git a/encodec/benchmarks/bench_mx.py b/encodec/benchmarks/bench_mx.py
new file mode 100644
index 00000000..2acd4b75
--- /dev/null
+++ b/encodec/benchmarks/bench_mx.py
@@ -0,0 +1,30 @@
+# Copyright © 2024 Apple Inc.
+
+import time
+
+import mlx.core as mx
+from utils import load
+
+model, processor = load("mlx-community/encodec-48khz-float32")
+
+audio = mx.random.uniform(shape=(288000, 2))
+feats, mask = processor(audio)
+mx.eval(model, feats, mask)
+
+
+@mx.compile
+def fun():
+ codes, scales = model.encode(feats, mask, bandwidth=3)
+ reconstructed = model.decode(codes, scales, mask)
+ return reconstructed
+
+
+for _ in range(5):
+ mx.eval(fun())
+
+tic = time.time()
+for _ in range(10):
+ mx.eval(fun())
+toc = time.time()
+ms = 1000 * (toc - tic) / 10
+print(f"Time per it: {ms:.3f}")
diff --git a/encodec/benchmarks/bench_pt.py b/encodec/benchmarks/bench_pt.py
new file mode 100644
index 00000000..5d158a32
--- /dev/null
+++ b/encodec/benchmarks/bench_pt.py
@@ -0,0 +1,34 @@
+# Copyright © 2024 Apple Inc.
+
+import time
+
+import numpy as np
+import torch
+from transformers import AutoProcessor, EncodecModel
+
+processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
+audio = np.random.uniform(size=(2, 288000)).astype(np.float32)
+
+pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz").to("mps")
+pt_inputs = processor(
+ raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
+).to("mps")
+
+
+def fun():
+ pt_encoded = pt_model.encode(pt_inputs["input_values"], pt_inputs["padding_mask"])
+ pt_audio = pt_model.decode(
+ pt_encoded.audio_codes, pt_encoded.audio_scales, pt_inputs["padding_mask"]
+ )
+ torch.mps.synchronize()
+
+
+for _ in range(5):
+ fun()
+
+tic = time.time()
+for _ in range(10):
+ fun()
+toc = time.time()
+ms = 1000 * (toc - tic) / 10
+print(f"Time per it: {ms:.3f}")
diff --git a/encodec/convert.py b/encodec/convert.py
new file mode 100644
index 00000000..13bd31a6
--- /dev/null
+++ b/encodec/convert.py
@@ -0,0 +1,213 @@
+# Copyright © 2024 Apple Inc.
+
+import argparse
+import json
+from pathlib import Path
+from textwrap import dedent
+from types import SimpleNamespace
+from typing import Any, Dict, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+from huggingface_hub import snapshot_download
+from mlx.utils import tree_flatten
+
+import encodec
+
+
+def fetch_from_hub(hf_repo: str) -> Path:
+ model_path = Path(
+ snapshot_download(
+ repo_id=hf_repo,
+ allow_patterns=["*.json", "*.safetensors"],
+ )
+ )
+ return model_path
+
+
+def upload_to_hub(path: str, upload_repo: str, hf_path: str):
+ """
+ Uploads the model to Hugging Face hub.
+
+ Args:
+ path (str): Local path to the model.
+ upload_repo (str): Name of the HF repo to upload to.
+ hf_path (str): Path to the original Hugging Face model.
+ """
+ import os
+
+ from huggingface_hub import HfApi, ModelCard, logging
+
+ content = dedent(
+ f"""
+ ---
+ language: en
+ license: other
+ library: mlx
+ tags:
+ - mlx
+ ---
+
+ The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
+ converted to MLX format from
+ [{hf_path}](https://huggingface.co/{hf_path}).
+
+ This model is intended to be used with the [EnCodec MLX
+ example](https://github.com/ml-explore/mlx-examples/tree/main/encodec).
+ """
+ )
+
+ card = ModelCard(content)
+ card.save(os.path.join(path, "README.md"))
+
+ logging.set_verbosity_info()
+
+ api = HfApi()
+ api.create_repo(repo_id=upload_repo, exist_ok=True)
+ api.upload_folder(
+ folder_path=path,
+ repo_id=upload_repo,
+ repo_type="model",
+ multi_commits=True,
+ multi_commits_verbose=True,
+ )
+ print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
+
+
+def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
+ if isinstance(save_path, str):
+ save_path = Path(save_path)
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ total_size = sum(v.nbytes for v in weights.values())
+ index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
+ mx.save_safetensors(
+ str(save_path / "model.safetensors"), weights, metadata={"format": "mlx"}
+ )
+
+ for weight_name in weights.keys():
+ index_data["weight_map"][weight_name] = "model.safetensors"
+
+ index_data["weight_map"] = {
+ k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
+ }
+
+ with open(save_path / "model.safetensors.index.json", "w") as f:
+ json.dump(index_data, f, indent=4)
+
+
+def save_config(
+ config: dict,
+ config_path: Union[str, Path],
+) -> None:
+ """Save the model configuration to the ``config_path``.
+
+ The final configuration will be sorted before saving for better readability.
+
+ Args:
+ config (dict): The model configuration.
+ config_path (Union[str, Path]): Model configuration file path.
+ """
+ # Clean unused keys
+ config.pop("_name_or_path", None)
+
+ # sort the config for better readability
+ config = dict(sorted(config.items()))
+
+ # write the updated config to the config_path (if provided)
+ with open(config_path, "w") as fid:
+ json.dump(config, fid, indent=4)
+
+
+def convert(
+ upload: bool,
+ model: str,
+ dtype: str = None,
+):
+ hf_repo = f"facebook/encodec_{model}"
+ mlx_repo = f"mlx-community/encodec-{model}-{dtype}"
+ path = fetch_from_hub(hf_repo)
+ save_path = Path("mlx_models")
+
+ weights = mx.load(str(Path(path) / "model.safetensors"))
+
+ with open(path / "config.json", "r") as fid:
+ config = SimpleNamespace(**json.load(fid))
+
+ model = encodec.EncodecModel(config)
+
+ new_weights = {}
+ for k, v in weights.items():
+ basename, pname = k.rsplit(".", 1)
+ if pname == "weight_v":
+ g = weights[basename + ".weight_g"]
+ v = g * (v / mx.linalg.norm(v, axis=(1, 2), keepdims=True))
+ k = basename + ".weight"
+ elif pname in ["weight_g", "embed_avg", "cluster_size", "inited"]:
+ continue
+ elif "lstm" in basename:
+ w_or_b, ih_or_hh, ln = pname.split("_")
+ if w_or_b == "weight":
+ new_pname = "Wx" if ih_or_hh == "ih" else "Wh"
+ elif w_or_b == "bias" and ih_or_hh == "ih":
+ continue
+ else:
+ v = v + weights[k.replace("_hh_", "_ih_")]
+ new_pname = "bias"
+ k = basename + "." + ln[1:] + "." + new_pname
+ if "conv.weight" in k:
+ # Possibly a transposed conv which has a different order
+ if "decoder" in k:
+ ln = int(k.split(".")[2])
+ if "conv" in model.decoder.layers[ln] and isinstance(
+ model.decoder.layers[ln].conv, nn.ConvTranspose1d
+ ):
+ v = mx.moveaxis(v, 0, 2)
+ else:
+ v = mx.moveaxis(v, 1, 2)
+ else:
+ v = mx.moveaxis(v, 1, 2)
+
+ new_weights[k] = v
+ weights = new_weights
+
+ model.load_weights(list(weights.items()))
+
+ if dtype is not None:
+ t = getattr(mx, dtype)
+ weights = {k: v.astype(t) for k, v in weights.items()}
+
+ if isinstance(save_path, str):
+ save_path = Path(save_path)
+
+ save_weights(save_path, weights)
+
+ save_config(vars(config), config_path=save_path / "config.json")
+
+ if upload:
+ upload_to_hub(save_path, mlx_repo, hf_repo)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Convert EnCodec weights to MLX.")
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="48khz",
+ help="",
+ choices=["24khz", "32khz", "48khz"],
+ )
+ parser.add_argument(
+ "--upload",
+ action="store_true",
+ help="Upload the weights to Hugging Face.",
+ )
+ parser.add_argument(
+ "--dtype",
+ type=str,
+ help="Data type to convert the model to.",
+ default="float32",
+ choices=["float32", "bfloat16", "float16"],
+ )
+ args = parser.parse_args()
+ convert(upload=args.upload, model=args.model, dtype=args.dtype)
diff --git a/encodec/encodec.py b/encodec/encodec.py
new file mode 100644
index 00000000..3ef47369
--- /dev/null
+++ b/encodec/encodec.py
@@ -0,0 +1,671 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+_lstm_kernel = mx.fast.metal_kernel(
+ name="lstm",
+ input_names=["x", "h_in", "cell", "hidden_size", "time_step", "num_time_steps"],
+ output_names=["hidden_state", "cell_state"],
+ header="""
+ template <typename T>
+ T sigmoid(T x) {
+ auto y = 1 / (1 + metal::exp(-metal::abs(x)));
+ return (x < 0) ? 1 - y : y;
+ }
+ """,
+ source="""
+ uint b = thread_position_in_grid.x;
+ uint d = hidden_size * 4;
+
+ uint elem = b * d + thread_position_in_grid.y;
+ uint index = elem;
+ uint x_index = b * num_time_steps * d + time_step * d + index;
+
+ auto i = sigmoid(h_in[index] + x[x_index]);
+ index += hidden_size;
+ x_index += hidden_size;
+ auto f = sigmoid(h_in[index] + x[x_index]);
+ index += hidden_size;
+ x_index += hidden_size;
+ auto g = metal::precise::tanh(h_in[index] + x[x_index]);
+ index += hidden_size;
+ x_index += hidden_size;
+ auto o = sigmoid(h_in[index] + x[x_index]);
+
+ cell_state[elem] = f * cell[elem] + i * g;
+ hidden_state[elem] = o * metal::precise::tanh(cell_state[elem]);
+ """,
+)
+
+
+def lstm_custom(x, h_in, cell, time_step):
+ assert x.ndim == 3, "Input to LSTM must have 3 dimensions."
+ out_shape = cell.shape
+ return _lstm_kernel(
+ inputs=[x, h_in, cell, out_shape[-1], time_step, x.shape[-2]],
+ output_shapes=[out_shape, out_shape],
+ output_dtypes=[h_in.dtype, h_in.dtype],
+ grid=(x.shape[0], h_in.size // 4, 1),
+ threadgroup=(256, 1, 1),
+ )
+
+
+class LSTM(nn.Module):
+ def __init__(
+ self,
+ input_size: int,
+ hidden_size: int,
+ bias: bool = True,
+ ):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.Wx = mx.zeros((4 * hidden_size, input_size))
+ self.Wh = mx.zeros((4 * hidden_size, hidden_size))
+ self.bias = mx.zeros((4 * hidden_size,)) if bias else None
+
+ def __call__(self, x, hidden=None, cell=None):
+ if self.bias is not None:
+ x = mx.addmm(self.bias, x, self.Wx.T)
+ else:
+ x = x @ self.Wx.T
+
+ all_hidden = []
+
+ B = x.shape[0]
+ cell = cell or mx.zeros((B, self.hidden_size), x.dtype)
+ for t in range(x.shape[-2]):
+ if hidden is None:
+ hidden = mx.zeros((B, self.hidden_size * 4), x.dtype)
+ else:
+ hidden = hidden @ self.Wh.T
+ hidden, cell = lstm_custom(x, hidden, cell, t)
+ all_hidden.append(hidden)
+
+ return mx.stack(all_hidden, axis=-2)
+
+
+class EncodecConv1d(nn.Module):
+ """Conv1d with asymmetric or causal padding and normalization."""
+
+ def __init__(
+ self,
+ config,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ dilation: int = 1,
+ ):
+ super().__init__()
+ self.causal = config.use_causal_conv
+ self.pad_mode = config.pad_mode
+ self.norm_type = config.norm_type
+
+ self.conv = nn.Conv1d(
+ in_channels, out_channels, kernel_size, stride, dilation=dilation
+ )
+ if self.norm_type == "time_group_norm":
+ self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
+
+ self.stride = stride
+
+ # Effective kernel size with dilations.
+ self.kernel_size = (kernel_size - 1) * dilation + 1
+
+ self.padding_total = kernel_size - stride
+
+ def _get_extra_padding_for_conv1d(
+ self,
+ hidden_states: mx.array,
+ ) -> mx.array:
+ length = hidden_states.shape[1]
+ n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
+ n_frames = int(math.ceil(n_frames)) - 1
+ ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
+ return ideal_length - length
+
+ def _pad1d(
+ self,
+ hidden_states: mx.array,
+ paddings: Tuple[int, int],
+ mode: str = "zero",
+ value: float = 0.0,
+ ):
+ if mode != "reflect":
+ return mx.pad(
+ hidden_states, paddings, mode="constant", constant_values=value
+ )
+
+ length = hidden_states.shape[1]
+ prefix = hidden_states[:, 1 : paddings[0] + 1][:, ::-1]
+ suffix = hidden_states[:, max(length - (paddings[1] + 1), 0) : -1][:, ::-1]
+ return mx.concatenate([prefix, hidden_states, suffix], axis=1)
+
+ def __call__(self, hidden_states):
+ extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
+
+ if self.causal:
+ # Left padding for causal
+ hidden_states = self._pad1d(
+ hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode
+ )
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = self.padding_total // 2
+ padding_left = self.padding_total - padding_right
+ hidden_states = self._pad1d(
+ hidden_states,
+ (padding_left, padding_right + extra_padding),
+ mode=self.pad_mode,
+ )
+
+ hidden_states = self.conv(hidden_states)
+
+ if self.norm_type == "time_group_norm":
+ hidden_states = self.norm(hidden_states)
+
+ return hidden_states
+
+
+class EncodecConvTranspose1d(nn.Module):
+ """ConvTranspose1d with asymmetric or causal padding and normalization."""
+
+ def __init__(
+ self,
+ config,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ ):
+ super().__init__()
+ self.causal = config.use_causal_conv
+ self.trim_right_ratio = config.trim_right_ratio
+ self.norm_type = config.norm_type
+ self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
+ if config.norm_type == "time_group_norm":
+ self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
+ self.padding_total = kernel_size - stride
+
+ def __call__(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ if self.norm_type == "time_group_norm":
+ hidden_states = self.norm(hidden_states)
+
+ if self.causal:
+ padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
+ else:
+ padding_right = self.padding_total // 2
+
+ padding_left = self.padding_total - padding_right
+
+ end = hidden_states.shape[1] - padding_right
+ hidden_states = hidden_states[:, padding_left:end, :]
+ return hidden_states
+
+
+class EncodecLSTM(nn.Module):
+ def __init__(self, config, dimension):
+ super().__init__()
+ self.lstm = [LSTM(dimension, dimension) for _ in range(config.num_lstm_layers)]
+
+ def __call__(self, hidden_states):
+ h = hidden_states
+ for lstm in self.lstm:
+ h = lstm(h)
+ return h + hidden_states
+
+
+class EncodecResnetBlock(nn.Module):
+ """
+ Residual block from SEANet model as used by EnCodec.
+ """
+
+ def __init__(self, config, dim: int, dilations: List[int]):
+ super().__init__()
+ kernel_sizes = (config.residual_kernel_size, 1)
+ if len(kernel_sizes) != len(dilations):
+ raise ValueError("Number of kernel sizes should match number of dilations")
+
+ hidden = dim // config.compress
+ block = []
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+ in_chs = dim if i == 0 else hidden
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+ block += [nn.ELU()]
+ block += [
+ EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)
+ ]
+ self.block = block
+
+ if getattr(config, "use_conv_shortcut", True):
+ self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
+ else:
+ self.shortcut = nn.Identity()
+
+ def __call__(self, hidden_states):
+ residual = hidden_states
+ for layer in self.block:
+ hidden_states = layer(hidden_states)
+
+ return self.shortcut(residual) + hidden_states
+
+
+class EncodecEncoder(nn.Module):
+ """SEANet encoder as used by EnCodec."""
+
+ def __init__(self, config):
+ super().__init__()
+ model = [
+ EncodecConv1d(
+ config, config.audio_channels, config.num_filters, config.kernel_size
+ )
+ ]
+ scaling = 1
+
+ for ratio in reversed(config.upsampling_ratios):
+ current_scale = scaling * config.num_filters
+ for j in range(config.num_residual_layers):
+ model += [
+ EncodecResnetBlock(
+ config, current_scale, [config.dilation_growth_rate**j, 1]
+ )
+ ]
+ model += [nn.ELU()]
+ model += [
+ EncodecConv1d(
+ config,
+ current_scale,
+ current_scale * 2,
+ kernel_size=ratio * 2,
+ stride=ratio,
+ )
+ ]
+ scaling *= 2
+
+ model += [EncodecLSTM(config, scaling * config.num_filters)]
+ model += [nn.ELU()]
+ model += [
+ EncodecConv1d(
+ config,
+ scaling * config.num_filters,
+ config.hidden_size,
+ config.last_kernel_size,
+ )
+ ]
+
+ self.layers = model
+
+ def __call__(self, hidden_states):
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+
+class EncodecDecoder(nn.Module):
+ """SEANet decoder as used by EnCodec."""
+
+ def __init__(self, config):
+ super().__init__()
+ scaling = int(2 ** len(config.upsampling_ratios))
+ model = [
+ EncodecConv1d(
+ config,
+ config.hidden_size,
+ scaling * config.num_filters,
+ config.kernel_size,
+ )
+ ]
+
+ model += [EncodecLSTM(config, scaling * config.num_filters)]
+
+ for ratio in config.upsampling_ratios:
+ current_scale = scaling * config.num_filters
+ model += [nn.ELU()]
+ model += [
+ EncodecConvTranspose1d(
+ config,
+ current_scale,
+ current_scale // 2,
+ kernel_size=ratio * 2,
+ stride=ratio,
+ )
+ ]
+ for j in range(config.num_residual_layers):
+ model += [
+ EncodecResnetBlock(
+ config, current_scale // 2, (config.dilation_growth_rate**j, 1)
+ )
+ ]
+ scaling //= 2
+
+ model += [nn.ELU()]
+ model += [
+ EncodecConv1d(
+ config,
+ config.num_filters,
+ config.audio_channels,
+ config.last_kernel_size,
+ )
+ ]
+ self.layers = model
+
+ def __call__(self, hidden_states):
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+
+class EncodecEuclideanCodebook(nn.Module):
+ """Codebook with Euclidean distance."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.embed = mx.zeros((config.codebook_size, config.codebook_dim))
+
+ def quantize(self, hidden_states):
+ embed = self.embed.T
+ scaled_states = hidden_states.square().sum(axis=1, keepdims=True)
+ dist = -(
+ scaled_states
+ - 2 * hidden_states @ embed
+ + embed.square().sum(axis=0, keepdims=True)
+ )
+ embed_ind = dist.argmax(axis=-1)
+ return embed_ind
+
+ def encode(self, hidden_states):
+ shape = hidden_states.shape
+ hidden_states = hidden_states.reshape((-1, shape[-1]))
+ embed_ind = self.quantize(hidden_states)
+ embed_ind = embed_ind.reshape(*shape[:-1])
+ return embed_ind
+
+ def decode(self, embed_ind):
+ return self.embed[embed_ind]
+
+
+class EncodecVectorQuantization(nn.Module):
+ """
+ Vector quantization implementation. Currently supports only euclidean distance.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.codebook = EncodecEuclideanCodebook(config)
+
+ def encode(self, hidden_states):
+ return self.codebook.encode(hidden_states)
+
+ def decode(self, embed_ind):
+ return self.codebook.decode(embed_ind)
+
+
+class EncodecResidualVectorQuantizer(nn.Module):
+ """Residual Vector Quantizer."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.codebook_size = config.codebook_size
+
+ hop_length = np.prod(config.upsampling_ratios)
+ self.frame_rate = math.ceil(config.sampling_rate / hop_length)
+ self.num_quantizers = int(
+ 1000 * config.target_bandwidths[-1] // (self.frame_rate * 10)
+ )
+ self.layers = [
+ EncodecVectorQuantization(config) for _ in range(self.num_quantizers)
+ ]
+
+ def get_num_quantizers_for_bandwidth(
+ self, bandwidth: Optional[float] = None
+ ) -> int:
+ """Return num_quantizers based on specified target bandwidth."""
+ bw_per_q = math.log2(self.codebook_size) * self.frame_rate
+ num_quantizers = self.num_quantizers
+ if bandwidth is not None and bandwidth > 0.0:
+ num_quantizers = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
+ return num_quantizers
+
+ def encode(
+ self, embeddings: mx.array, bandwidth: Optional[float] = None
+ ) -> mx.array:
+ """
+ Encode a given input array with the specified frame rate at the given
+ bandwidth. The RVQ encode method sets the appropriate number of
+ quantizers to use and returns indices for each quantizer.
+ """
+ num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
+ residual = embeddings
+ all_indices = []
+ for layer in self.layers[:num_quantizers]:
+ indices = layer.encode(residual)
+ quantized = layer.decode(indices)
+ residual = residual - quantized
+ all_indices.append(indices)
+ out_indices = mx.stack(all_indices, axis=1)
+ return out_indices
+
+ def decode(self, codes: mx.array) -> mx.array:
+ """Decode the given codes to the quantized representation."""
+ quantized_out = None
+ for i, indices in enumerate(codes.split(codes.shape[1], axis=1)):
+ layer = self.layers[i]
+ quantized = layer.decode(indices.squeeze(1))
+ if quantized_out is None:
+ quantized_out = quantized
+ else:
+ quantized_out = quantized + quantized_out
+ return quantized_out
+
+
+class EncodecModel(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.encoder = EncodecEncoder(config)
+ self.decoder = EncodecDecoder(config)
+ self.quantizer = EncodecResidualVectorQuantizer(config)
+
+ def _encode_frame(
+ self, input_values: mx.array, bandwidth: float, padding_mask: mx.array
+ ) -> Tuple[mx.array, Optional[mx.array]]:
+ """
+ Encodes the given input using the underlying VQVAE.
+ """
+ length = input_values.shape[1]
+ duration = length / self.config.sampling_rate
+
+ if (
+ self.config.chunk_length_s is not None
+ and duration > 1e-5 + self.config.chunk_length_s
+ ):
+ raise RuntimeError(
+ f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}"
+ )
+
+ scale = None
+ if self.config.normalize:
+ # if the padding is non zero
+ input_values = input_values * padding_mask[..., None]
+ mono = mx.sum(input_values, axis=2, keepdims=True) / input_values.shape[2]
+ scale = mono.square().mean(axis=1, keepdims=True).sqrt() + 1e-8
+ input_values = input_values / scale
+
+ embeddings = self.encoder(input_values)
+ codes = self.quantizer.encode(embeddings, bandwidth)
+ return codes, scale
+
+ def encode(
+ self,
+ input_values: mx.array,
+ padding_mask: mx.array = None,
+ bandwidth: Optional[float] = None,
+ ) -> Tuple[mx.array, Optional[mx.array]]:
+ """
+ Encodes the input audio waveform into discrete codes.
+
+ Args:
+ input_values (mx.array): The input audio waveform with shape
+ ``(batch_size, channels, sequence_length)``.
+ padding_mask (mx.array): Padding mask used to pad the ``input_values``.
+ bandwidth (float, optional): The target bandwidth. Must be one of
+ ``config.target_bandwidths``. If ``None``, uses the smallest
+ possible bandwidth. bandwidth is represented as a thousandth of
+ what it is, e.g. 6kbps bandwidth is represented as bandwidth == 6.0
+
+ Returns:
+ A list of frames containing the discrete encoded codes for the
+ input audio waveform, along with rescaling factors for each chunk
+ when ``config.normalize==True``. Each frame is a tuple ``(codebook,
+ scale)``, with ``codebook`` of shape ``(batch_size, num_codebooks,
+ frames)``.
+ """
+
+ if bandwidth is None:
+ bandwidth = self.config.target_bandwidths[0]
+ if bandwidth not in self.config.target_bandwidths:
+ raise ValueError(
+ f"This model doesn't support the bandwidth {bandwidth}. "
+ f"Select one of {self.config.target_bandwidths}."
+ )
+
+ _, input_length, channels = input_values.shape
+
+ if channels < 1 or channels > 2:
+ raise ValueError(
+ f"Number of audio channels must be 1 or 2, but got {channels}"
+ )
+
+ chunk_length = self.chunk_length
+ if chunk_length is None:
+ chunk_length = input_length
+ stride = input_length
+ else:
+ stride = self.chunk_stride
+
+ if padding_mask is None:
+ padding_mask = mx.ones(input_values.shape[:2], dtype=mx.bool_)
+ encoded_frames = []
+ scales = []
+
+ step = chunk_length - stride
+ if (input_length % stride) != step:
+ raise ValueError(
+ "The input length is not properly padded for batched chunked "
+ "encoding. Make sure to pad the input correctly."
+ )
+
+ for offset in range(0, input_length - step, stride):
+ mask = padding_mask[:, offset : offset + chunk_length].astype(mx.bool_)
+ frame = input_values[:, offset : offset + chunk_length]
+ encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
+ encoded_frames.append(encoded_frame)
+ scales.append(scale)
+
+ encoded_frames = mx.stack(encoded_frames)
+
+ return (encoded_frames, scales)
+
+ @staticmethod
+ def _linear_overlap_add(frames: List[mx.array], stride: int):
+ if len(frames) == 0:
+ raise ValueError("`frames` cannot be an empty list.")
+
+ dtype = frames[0].dtype
+ N, frame_length, C = frames[0].shape
+ total_size = stride * (len(frames) - 1) + frames[-1].shape[1]
+
+ time_vec = mx.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
+ weight = 0.5 - (time_vec - 0.5).abs()
+
+ weight = weight[:, None]
+ sum_weight = mx.zeros((total_size, 1), dtype=dtype)
+ out = mx.zeros((N, total_size, C), dtype=dtype)
+ offset = 0
+
+ for frame in frames:
+ frame_length = frame.shape[1]
+ out[:, offset : offset + frame_length] += weight[:frame_length] * frame
+ sum_weight[offset : offset + frame_length] += weight[:frame_length]
+ offset += stride
+
+ return out / sum_weight
+
+ def _decode_frame(
+ self, codes: mx.array, scale: Optional[mx.array] = None
+ ) -> mx.array:
+ embeddings = self.quantizer.decode(codes)
+ outputs = self.decoder(embeddings)
+ if scale is not None:
+ outputs = outputs * scale
+ return outputs
+
+ @property
+ def channels(self):
+ return self.config.audio_channels
+
+ @property
+ def sampling_rate(self):
+ return self.config.sampling_rate
+
+ @property
+ def chunk_length(self):
+ if self.config.chunk_length_s is None:
+ return None
+ else:
+ return int(self.config.chunk_length_s * self.config.sampling_rate)
+
+ @property
+ def chunk_stride(self):
+ if self.config.chunk_length_s is None or self.config.overlap is None:
+ return None
+ else:
+ return max(1, int((1.0 - self.config.overlap) * self.chunk_length))
+
+ def decode(
+ self,
+ audio_codes: mx.array,
+ audio_scales: Union[mx.array, List[mx.array]],
+ padding_mask: Optional[mx.array] = None,
+ ) -> Tuple[mx.array, mx.array]:
+ """
+ Decodes the given frames into an output audio waveform.
+
+ Note that the output might be a bit bigger than the input. In that
+ case, any extra steps at the end should be trimmed.
+
+ Args:
+ audio_codes (mx.array): Discrete code embeddings of shape
+ ``(batch_size, nb_chunks, chunk_length)``.
+ audio_scales (mx.array): Scaling factor for each input.
+ padding_mask (mx.array): Padding mask.
+ """
+ chunk_length = self.chunk_length
+ if chunk_length is None:
+ if audio_codes.shape[1] != 1:
+ raise ValueError(f"Expected one frame, got {len(audio_codes)}")
+ audio_values = self._decode_frame(audio_codes[:, 0], audio_scales[0])
+ else:
+ decoded_frames = []
+
+ for frame, scale in zip(audio_codes, audio_scales):
+ frames = self._decode_frame(frame, scale)
+ decoded_frames.append(frames)
+
+ audio_values = self._linear_overlap_add(
+ decoded_frames, self.chunk_stride or 1
+ )
+
+ # truncate based on padding mask
+ if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
+ audio_values = audio_values[:, : padding_mask.shape[1]]
+ return audio_values
diff --git a/encodec/example.py b/encodec/example.py
new file mode 100644
index 00000000..97b311a1
--- /dev/null
+++ b/encodec/example.py
@@ -0,0 +1,37 @@
+# Copyright © 2024 Apple Inc.
+
+import mlx.core as mx
+from utils import load, load_audio, save_audio
+
+# Load the 48 kHz model and preprocessor.
+model, processor = load("mlx-community/encodec-48khz-float32")
+
+# Load an audio file
+audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)
+
+# Preprocess the audio (this can also be a list of arrays for batched
+# processing).
+feats, mask = processor(audio)
+
+
+# Encode at the given bandwidth. A lower bandwidth results in more
+# compression but lower reconstruction quality.
+@mx.compile
+def encode(feats, mask):
+ return model.encode(feats, mask, bandwidth=3)
+
+
+# Decode to reconstruct the audio
+@mx.compile
+def decode(codes, scales, mask):
+ return model.decode(codes, scales, mask)
+
+
+codes, scales = encode(feats, mask)
+reconstructed = decode(codes, scales, mask)
+
+# Trim any padding:
+reconstructed = reconstructed[0, : len(audio)]
+
+# Save the audio as a wave file
+save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
diff --git a/encodec/requirements.txt b/encodec/requirements.txt
new file mode 100644
index 00000000..de5cc646
--- /dev/null
+++ b/encodec/requirements.txt
@@ -0,0 +1,3 @@
+mlx>=0.18
+numpy
+huggingface_hub
diff --git a/encodec/test.py b/encodec/test.py
new file mode 100644
index 00000000..ffc23505
--- /dev/null
+++ b/encodec/test.py
@@ -0,0 +1,66 @@
+# Copyright © 2024 Apple Inc.
+
+import mlx.core as mx
+import numpy as np
+import torch
+from datasets import Audio, load_dataset
+from transformers import AutoProcessor, EncodecModel
+from utils import load, load_audio, preprocess_audio
+
+
+def compare_processors():
+ np.random.seed(0)
+ audio_length = 95500
+ audio = np.random.uniform(size=(2, audio_length)).astype(np.float32)
+
+ processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
+
+ pt_inputs = processor(
+ raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
+ )
+ mx_inputs = preprocess_audio(
+ mx.array(audio).T,
+ processor.sampling_rate,
+ processor.chunk_length,
+ processor.chunk_stride,
+ )
+
+ assert np.array_equal(pt_inputs["input_values"], mx_inputs[0].moveaxis(2, 1))
+ assert np.array_equal(pt_inputs["padding_mask"], mx_inputs[1])
+
+
+def compare_models():
+ pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
+ mx_model, _ = load("mlx-community/encodec-48khz-float32")
+
+ np.random.seed(0)
+ audio_length = 190560
+ audio = np.random.uniform(size=(1, audio_length, 2)).astype(np.float32)
+ mask = np.ones((1, audio_length), dtype=np.int32)
+ pt_encoded = pt_model.encode(
+ torch.tensor(audio).moveaxis(2, 1), torch.tensor(mask)[None]
+ )
+ mx_encoded = mx_model.encode(mx.array(audio), mx.array(mask))
+ pt_codes = pt_encoded.audio_codes.numpy()
+ mx_codes = mx_encoded[0]
+ assert np.array_equal(pt_codes, mx_codes), "Encoding codes mismatch"
+
+ for mx_scale, pt_scale in zip(mx_encoded[1], pt_encoded.audio_scales):
+ if mx_scale is not None:
+ pt_scale = pt_scale.numpy()
+ assert np.allclose(pt_scale, mx_scale, atol=1e-3, rtol=1e-4)
+
+ pt_audio = pt_model.decode(
+ pt_encoded.audio_codes, pt_encoded.audio_scales, torch.tensor(mask)[None]
+ )
+ pt_audio = pt_audio[0].squeeze().T.detach().numpy()
+ mx_audio = mx_model.decode(*mx_encoded, mx.array(mask))
+ mx_audio = mx_audio.squeeze()
+ assert np.allclose(
+ pt_audio, mx_audio, atol=1e-4, rtol=1e-4
+ ), "Decoding audio mismatch"
+
+
+if __name__ == "__main__":
+ compare_processors()
+ compare_models()
diff --git a/encodec/utils.py b/encodec/utils.py
new file mode 100644
index 00000000..18b3f063
--- /dev/null
+++ b/encodec/utils.py
@@ -0,0 +1,129 @@
+# Copyright © 2024 Apple Inc.
+
+import functools
+import json
+from pathlib import Path
+from types import SimpleNamespace
+from typing import List, Optional, Union
+
+import mlx.core as mx
+import numpy as np
+from huggingface_hub import snapshot_download
+
+import encodec
+
+
+def save_audio(file: str, audio: mx.array, sampling_rate: int):
+ """
+ Save audio to a wave (.wav) file.
+ """
+ from scipy.io.wavfile import write
+
+ audio = (audio * 32767).astype(mx.int16)
+ write(file, sampling_rate, np.array(audio))
+
+
+def load_audio(file: str, sampling_rate: int, channels: int):
+ """
+ Read audio into an mx.array, resampling if necessary.
+
+ Args:
+ file (str): The audio file to open.
+ sampling_rate (int): The sample rate to resample the audio at if needed.
+ channels (int): The number of audio channels.
+
+ Returns:
+ An mx.array containing the audio waveform in float32.
+ """
+ from subprocess import CalledProcessError, run
+
+ # This launches a subprocess to decode audio while down-mixing
+ # and resampling as necessary. Requires the ffmpeg CLI in PATH.
+ # fmt: off
+ cmd = [
+ "ffmpeg",
+ "-nostdin",
+ "-threads", "0",
+ "-i", file,
+ "-f", "s16le",
+ "-ac", str(channels),
+ "-acodec", "pcm_s16le",
+ "-ar", str(sampling_rate),
+ "-"
+ ]
+ # fmt: on
+ try:
+ out = run(cmd, capture_output=True, check=True).stdout
+ except CalledProcessError as e:
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+
+ out = mx.array(np.frombuffer(out, np.int16))
+ return out.reshape(-1, channels).astype(mx.float32) / 32767.0
+
+
+def preprocess_audio(
+ raw_audio: Union[mx.array, List[mx.array]],
+ sampling_rate: int = 24000,
+ chunk_length: Optional[int] = None,
+ chunk_stride: Optional[int] = None,
+):
+ r"""
+ Prepare inputs for the EnCodec model.
+
+ Args:
+ raw_audio (mx.array or List[mx.array]): The sequence or batch of
+ sequences to be processed.
+ sampling_rate (int): The sampling rate at which the audio waveform
+ should be digitalized.
+ chunk_length (int, optional): The model's chunk length.
+ chunk_stride (int, optional): The model's chunk stride.
+ """
+ if not isinstance(raw_audio, list):
+ raw_audio = [raw_audio]
+
+ raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
+
+ max_length = max(array.shape[0] for array in raw_audio)
+ if chunk_length is not None:
+ max_length += chunk_length - (max_length % chunk_stride)
+
+ inputs = []
+ masks = []
+ for x in raw_audio:
+ length = x.shape[0]
+ mask = mx.ones((length,), dtype=mx.bool_)
+ difference = max_length - length
+ if difference > 0:
+ mask = mx.pad(mask, (0, difference))
+ x = mx.pad(x, ((0, difference), (0, 0)))
+ inputs.append(x)
+ masks.append(mask)
+ return mx.stack(inputs), mx.stack(masks)
+
+
+def load(path_or_repo):
+ """
+ Load the model and audio preprocessor.
+ """
+ path = Path(path_or_repo)
+ if not path.exists():
+ path = Path(
+ snapshot_download(
+ repo_id=path_or_repo,
+ allow_patterns=["*.json", "*.safetensors", "*.model"],
+ )
+ )
+
+ with open(path / "config.json", "r") as f:
+ config = SimpleNamespace(**json.load(f))
+
+ model = encodec.EncodecModel(config)
+ model.load_weights(str(path / "model.safetensors"))
+ processor = functools.partial(
+ preprocess_audio,
+ sampling_rate=config.sampling_rate,
+ chunk_length=model.chunk_length,
+ chunk_stride=model.chunk_stride,
+ )
+ mx.eval(model)
+ return model, processor
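A short usage sketch for the preprocessor above (run from the `encodec/` directory; lengths and chunk sizes are illustrative): clips of different lengths are padded to a common chunk-aligned length, and the masks record which samples are real versus padding.

```python
import mlx.core as mx
from utils import preprocess_audio  # assumes the encodec/ directory is on the path

a = mx.random.uniform(shape=(1000,))
b = mx.random.uniform(shape=(1500,))

# Both clips are padded to the next multiple of the chunk stride past the
# longest clip; masks are False over the padded tail.
feats, masks = preprocess_audio(
    [a, b], sampling_rate=24000, chunk_length=480, chunk_stride=480
)
print(feats.shape, masks.shape)  # (2, 1920, 1) (2, 1920)
```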
From e776c970f708e7a64e192d92cd90f100105a4fd6 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Wed, 25 Sep 2024 23:19:41 +0900
Subject: [PATCH 032/188] Fix llava model when using text-only prompt (#998)
---
llava/llava.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/llava/llava.py b/llava/llava.py
index 06e56059..9e6b7511 100644
--- a/llava/llava.py
+++ b/llava/llava.py
@@ -68,11 +68,10 @@ class LlavaModel(nn.Module):
input_ids: Optional[mx.array] = None,
pixel_values: Optional[mx.array] = None,
):
- if pixel_values is None:
- return self.language_model(input_ids)
-
# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+ if pixel_values is None:
+ return inputs_embeds
        # Get the output hidden states from the vision model
*_, hidden_states = self.vision_tower(
From 76710f61af401457e505b851ff729d93489e07b1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?=
<60228478+Goekdeniz-Guelmez@users.noreply.github.com>
Date: Sat, 28 Sep 2024 16:02:53 +0200
Subject: [PATCH 033/188] Adding support for mamba (#940)
* initial commit
* initial commit
* Adding first lines
* adding x, and dt projection layers
* adding the clamping mechanism
* First successful inference
* last commit for today - added custom generate function and it works as expected; will try training and then loading a model from the hub
* clean up
* save up
* almost
* update
* update
* fixed cache handling
* fixed loading
* added separate generate_step method in the model and also in the utils to automatically use the generate_step method in the model class
* quick update
* still not working
* save
* still not working
* initial commit
* utils.py logits = logits[:, -1, :] TypeError: tuple indices must be integers or slices, not tuple
* update
* update
* Fixing the batching, depthwise convolution, and multi-token input
* fixing generate and logits outputs
* Done!
* Fixing the cache handling, generating works now trying training
* update ACKNOWLEDGEMENTS
* removing the model_type conditional in the _step loop in generate_step, adding MambaCache in base.py for easier generation during training, and removing mamba from tuner/utils.
* quick clean up
* update trainer/utils for correct initialisation of the LoRA layers, but not working.
* clean up
* Further update to trainer/utils for correct layer selection. Successful training
* removing extra mamba-infer.py file
* clean up, reformatting will come later
* reformat and big clean up, final commit
* some speedups and cleanups
* fix test
* nits
* nits
---------
Co-authored-by: Awni Hannun
---
ACKNOWLEDGMENTS.md | 1 +
llms/mlx_lm/models/mamba.py | 231 ++++++++++++++++++++++++++++++++++++
llms/mlx_lm/tuner/utils.py | 10 +-
llms/tests/test_models.py | 29 +++--
4 files changed, 263 insertions(+), 8 deletions(-)
create mode 100644 llms/mlx_lm/models/mamba.py
diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 2b98bc95..2037a076 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -14,3 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
+- Gökdeniz Gülmez: Added support for `MiniCPM` and `Mamba`.
diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
new file mode 100644
index 00000000..26408426
--- /dev/null
+++ b/llms/mlx_lm/models/mamba.py
@@ -0,0 +1,231 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from dataclasses import dataclass
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+ model_type: str
+ vocab_size: int
+ hidden_size: int
+ intermediate_size: int
+ state_size: int
+ num_hidden_layers: int
+ conv_kernel: int
+ use_bias: bool
+ use_conv_bias: bool
+ time_step_rank: int
+ tie_word_embeddings: bool = True
+
+ def __post_init__(self):
+ if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
+ self.hidden_size = self.d_model
+ if not hasattr(self, "intermediate_size") and hasattr(self, "d_inner"):
+ self.intermediate_size = self.d_inner
+ if not hasattr(self, "state_size") and hasattr(self, "d_state"):
+ self.state_size = self.d_state
+ if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layer"):
+ self.num_hidden_layers = self.n_layer
+ if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layers"):
+ self.num_hidden_layers = self.n_layers
+ if not hasattr(self, "conv_kernel") and hasattr(self, "d_conv"):
+ self.conv_kernel = self.d_conv
+ if not hasattr(self, "use_bias") and hasattr(self, "bias"):
+ self.use_bias = self.bias
+ if not hasattr(self, "use_conv_bias") and hasattr(self, "conv_bias"):
+ self.use_conv_bias = self.conv_bias
+
+ if self.time_step_rank == "auto":
+ self.time_step_rank = math.ceil(self.hidden_size / 16)
+
+
+class MambaCache:
+ def __init__(self):
+ self.cache = [None, None]
+
+ def __setitem__(self, idx, value):
+ self.cache[idx] = value
+
+ def __getitem__(self, idx):
+ return self.cache[idx]
+
+ @property
+ def state(self):
+ return self.cache
+
+
+class DepthWiseConv1d(nn.Module):
+ def __init__(self, channels, kernel_size, bias=True, padding=0):
+ super().__init__()
+ self.channels = channels
+ self.kernel_size = kernel_size
+ self.padding = padding
+ self.weight = mx.random.normal((self.channels, kernel_size, 1))
+ self.bias = mx.zeros((channels,)) if bias else None
+
+ def __call__(self, x, cache=None):
+ B, L, C = x.shape
+ groups, K, _ = self.weight.shape
+
+ if cache is not None:
+ x = mx.concatenate([cache, x], axis=1)
+ else:
+ x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
+
+ y = mx.conv_general(x, self.weight, groups=groups)
+
+ if self.bias is not None:
+ y = y + self.bias
+
+ return y, x[:, -K + 1 :, :]
+
+
+class MambaBlock(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.args = args
+
+ self.hidden_size = args.hidden_size
+ self.ssm_state_size = args.state_size
+ self.conv_kernel_size = args.conv_kernel
+ self.intermediate_size = args.intermediate_size
+ self.time_step_rank = int(args.time_step_rank)
+ self.use_conv_bias = args.use_conv_bias
+
+ self.in_proj = nn.Linear(
+ self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
+ )
+
+ self.conv1d = DepthWiseConv1d(
+ channels=self.intermediate_size,
+ kernel_size=self.conv_kernel_size,
+ bias=self.use_conv_bias,
+ padding=self.conv_kernel_size - 1,
+ )
+
+ self.x_proj = nn.Linear(
+ self.intermediate_size,
+ self.time_step_rank + 2 * self.ssm_state_size,
+ bias=False,
+ )
+ self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
+
+ A = mx.repeat(
+ mx.arange(1.0, self.ssm_state_size + 1.0).reshape([1, self.ssm_state_size]),
+ repeats=self.intermediate_size,
+ axis=0,
+ )
+ self.A_log = mx.log(A)
+ self.D = mx.ones([self.intermediate_size])
+
+ self.out_proj = nn.Linear(
+ self.intermediate_size, self.hidden_size, bias=args.use_bias
+ )
+
+ def ssm_step(self, x, state=None):
+ A = -mx.exp(self.A_log)
+ D = self.D
+ deltaBC = self.x_proj(x)
+ delta, B, C = mx.split(
+ deltaBC,
+ indices_or_sections=[
+ self.time_step_rank,
+ self.time_step_rank + self.ssm_state_size,
+ ],
+ axis=-1,
+ )
+ delta = nn.softplus(self.dt_proj(delta))
+ new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
+ if state is not None:
+ new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
+ y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
+ y = y + D * x
+ return y, new_state
+
+ def __call__(self, x, cache):
+ B, T, D = x.shape
+ if cache is None:
+ cache = [None, None]
+
+ outputs = []
+ for t in range(T):
+ xt = x[:, t, :]
+ xz = self.in_proj(xt)
+ x_t, z_t = xz.split(indices_or_sections=2, axis=1)
+ conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
+ x_t = conv_out.squeeze(1)
+ x_t = nn.silu(x_t)
+ y_t, cache[1] = self.ssm_step(x_t, cache[1])
+ z_t = nn.silu(z_t)
+ output_t = y_t * z_t
+ output_t = self.out_proj(output_t)
+ outputs.append(output_t)
+ output = mx.stack(outputs, axis=1)
+ return output
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.mixer = MambaBlock(args)
+ self.norm = nn.RMSNorm(args.hidden_size)
+
+ def __call__(self, x: mx.array, cache):
+ return self.mixer(self.norm(x), cache) + x
+
+
+class Mamba(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
+ self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
+ self.norm_f = nn.RMSNorm(args.hidden_size)
+
+ def __call__(self, x: mx.array, cache):
+ x = self.embeddings(x)
+ if cache is None:
+ cache = [None] * len(self.layers)
+ for layer, c in zip(self.layers, cache):
+ x = layer(x, c)
+ return self.norm_f(x)
+
+
+class Model(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.args = args
+ self.model_type = args.model_type
+ self.backbone = Mamba(args)
+ if not args.tie_word_embeddings:
+ self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+ def __call__(self, inputs: mx.array, cache=None):
+ B, T = inputs.shape
+
+ x = self.backbone(inputs, cache)
+
+ if self.args.tie_word_embeddings:
+ logits = self.backbone.embeddings.as_linear(x)
+ else:
+ logits = self.lm_head(x)
+
+ return logits
+
+ def sanitize(self, weights):
+ for k, v in weights.items():
+ if "conv1d.weight" in k and v.ndim == 3:
+ weights[k] = v.moveaxis(2, 1)
+ return weights
+
+ def make_cache(self, batch_size: int = 1):
+ return [MambaCache() for _ in range(len(self.layers))]
+
+ @property
+ def layers(self):
+ return self.backbone.layers
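As a quick sanity check (not part of the patch), the model can be exercised on a toy configuration; the sizes below mirror the unit test added later in this commit:

```python
import mlx.core as mx
from mlx_lm.models import mamba

args = mamba.ModelArgs(
    model_type="mamba",
    vocab_size=10000,
    hidden_size=768,
    intermediate_size=1536,
    state_size=16,
    num_hidden_layers=24,
    conv_kernel=4,
    use_bias=False,
    use_conv_bias=True,
    time_step_rank=48,
)
model = mamba.Model(args)
cache = model.make_cache()           # one MambaCache per residual block
tokens = mx.array([[1, 2, 3]])       # (batch, sequence)
logits = model(tokens, cache=cache)  # shape (1, 3, vocab_size)
```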
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 71fbfaab..ab9d37aa 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -52,7 +52,6 @@ def linear_to_lora_layers(
use_dora (bool): If True, uses DoRA instead of LoRA.
Default: ``False``
"""
-
num_layers = len(model.layers)
if num_lora_layers < 0:
@@ -140,6 +139,15 @@ def linear_to_lora_layers(
"self_attn.kv_b_proj",
]
)
+ elif model.model_type == "mamba":
+ keys = set(
+ [
+ "mixer.in_proj",
+ "mixer.x_proj",
+ "mixer.dt_proj",
+ "mixer.out_proj",
+ ]
+ )
else:
raise ValueError(f"Lora does not support {model.model_type}")
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index fcf1dc33..cd7e7fd0 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -5,6 +5,7 @@ import unittest
import mlx.core as mx
from mlx.utils import tree_map
from mlx_lm.models.base import KVCache, RotatingKVCache
+from mlx_lm.utils import make_kv_caches
class TestModels(unittest.TestCase):
@@ -100,13 +101,7 @@ class TestModels(unittest.TestCase):
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
- kv_heads = (
- [model.n_kv_heads] * len(model.layers)
- if isinstance(model.n_kv_heads, int)
- else model.n_kv_heads
- )
- cache = [KVCache(model.head_dim, n) for n in kv_heads]
-
+ cache = make_kv_caches(model)
outputs = model(inputs, cache)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
@@ -397,6 +392,26 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
+ def test_mamba(self):
+ from mlx_lm.models import mamba
+
+ args = mamba.ModelArgs(
+ model_type="mamba",
+ vocab_size=10000,
+ use_bias=False,
+ use_conv_bias=True,
+ conv_kernel=4,
+ hidden_size=768,
+ num_hidden_layers=24,
+ state_size=16,
+ intermediate_size=1536,
+ time_step_rank=48,
+ )
+ model = mamba.Model(args)
+ self.model_test_runner(
+ model, args.model_type, args.vocab_size, args.num_hidden_layers
+ )
+
def test_gpt2(self):
from mlx_lm.models import gpt2
From d812516d3d55c25e55238aee1e083fac5378a07f Mon Sep 17 00:00:00 2001
From: jamesm131 <20141156+jamesm131@users.noreply.github.com>
Date: Sun, 29 Sep 2024 00:21:11 +1000
Subject: [PATCH 034/188] Add /v1/models endpoint to mlx_lm.server (#984)
* Add 'models' endpoint to server
* Add test for new 'models' server endpoint
* Check hf_cache for mlx models
* update tests to check hf_cache for models
* simplify test
* doc
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/SERVER.md | 14 +++++++++++++
llms/mlx_lm/server.py | 41 +++++++++++++++++++++++++++++++++++++++
llms/tests/test_server.py | 15 ++++++++++++++
3 files changed, 70 insertions(+)
diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md
index 9c42d410..55be1c9c 100644
--- a/llms/mlx_lm/SERVER.md
+++ b/llms/mlx_lm/SERVER.md
@@ -85,3 +85,17 @@ curl localhost:8080/v1/chat/completions \
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
  relative to the directory the server was started in.
+
+### List Models
+
+Use the `v1/models` endpoint to list available models:
+
+```shell
+curl localhost:8080/v1/models -H "Content-Type: application/json"
+```
+
+This will return a list of locally available models where each model in the
+list contains the following fields:
+
+- `"id"`: The Hugging Face repo id.
+- `"created"`: A timestamp representing the model creation time.
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 79ac1836..f2d8b86a 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -11,6 +11,7 @@ from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
import mlx.core as mx
+from huggingface_hub import scan_cache_dir
from .utils import generate_step, load
@@ -618,6 +619,46 @@ class APIHandler(BaseHTTPRequestHandler):
prompt = self.tokenizer.encode(prompt_text)
return mx.array(prompt)
+ def do_GET(self):
+ """
+ Respond to a GET request from a client.
+ """
+ if self.path == "/v1/models":
+ self.handle_models_request()
+ else:
+ self._set_completion_headers(404)
+ self.end_headers()
+ self.wfile.write(b"Not Found")
+
+ def handle_models_request(self):
+ """
+ Handle a GET request for the /v1/models endpoint.
+ """
+ self._set_completion_headers(200)
+ self.end_headers()
+
+ # Scan the cache directory for downloaded mlx models
+ hf_cache_info = scan_cache_dir()
+ downloaded_models = [
+ repo for repo in hf_cache_info.repos if "mlx" in repo.repo_id
+ ]
+
+ # Create a list of available models
+ models = [
+ {
+ "id": repo.repo_id,
+ "object": "model",
+ "created": self.created,
+ }
+ for repo in downloaded_models
+ ]
+
+ response = {"object": "list", "data": models}
+
+ response_json = json.dumps(response).encode()
+ self.wfile.write(response_json)
+ self.wfile.flush()
+
def run(
host: str,
diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py
index baea664a..cbcccfbe 100644
--- a/llms/tests/test_server.py
+++ b/llms/tests/test_server.py
@@ -1,5 +1,7 @@
# Copyright © 2024 Apple Inc.
+
import http
+import json
import threading
import unittest
@@ -77,6 +79,19 @@ class TestServer(unittest.TestCase):
self.assertIn("id", response_body)
self.assertIn("choices", response_body)
+ def test_handle_models(self):
+ url = f"http://localhost:{self.port}/v1/models"
+ response = requests.get(url)
+ self.assertEqual(response.status_code, 200)
+ response_body = json.loads(response.text)
+ self.assertEqual(response_body["object"], "list")
+ self.assertIsInstance(response_body["data"], list)
+ self.assertGreater(len(response_body["data"]), 0)
+ model = response_body["data"][0]
+ self.assertIn("id", model)
+ self.assertEqual(model["object"], "model")
+ self.assertIn("created", model)
+
def test_sequence_overlap(self):
from mlx_lm.server import sequence_overlap
From ace2bb58900dd382bbf3e6d9c58cbc3beee0261c Mon Sep 17 00:00:00 2001
From: nathan <97126670+nathanrchn@users.noreply.github.com>
Date: Sat, 28 Sep 2024 19:08:49 +0200
Subject: [PATCH 035/188] Add logits_processor option to generate_step function
(#983)
* Add logits_processor option for generation, as in the Hugging Face transformers library
* concatenation correction
* Rename the tokens variable for clarity
* remove the logit_bias argument from generate_step method
* fix the variable name
* nits + test
* test
* add back logit bias + test
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/utils.py | 25 +++++++++++++----
llms/tests/test_generate.py | 55 +++++++++++++++++++++++++++++++++++++
2 files changed, 74 insertions(+), 6 deletions(-)
create mode 100644 llms/tests/test_generate.py
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 5621609d..16271c3e 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -154,10 +154,11 @@ def generate_step(
top_p: float = 1.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
- logit_bias: Optional[Dict[int, float]] = None,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None,
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
+ logit_bias: Optional[Dict[int, float]] = None,
+ logits_processor: Optional[Callable[[mx.array, mx.array], mx.array]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -177,10 +178,13 @@ def generate_step(
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
- logit_bias (dictionary, optional): Additive logit bias.
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
+ logit_bias (dictionary, optional): Additive logit bias.
+ logits_processor (Callable[[mx.array, mx.array], mx.array], optional):
+ A function that takes tokens and logits and returns the processed
+ logits. Default: ``None``.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@@ -188,10 +192,6 @@ def generate_step(
"""
def sample(logits: mx.array) -> Tuple[mx.array, float]:
- if logit_bias:
- indices = mx.array(list(logit_bias.keys()))
- values = mx.array(list(logit_bias.values()))
- logits[:, indices] += values
logprobs = logits - mx.logsumexp(logits)
if temp == 0:
@@ -214,6 +214,7 @@ def generate_step(
)
y = prompt
+ tokens = None
# Create the KV cache for generation
cache = make_kv_caches(model, max_kv_size)
@@ -233,11 +234,23 @@ def generate_step(
if repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]
+ if logit_bias:
+ indices = mx.array(list(logit_bias.keys()))
+ values = mx.array(list(logit_bias.values()))
+
def _step(y):
nonlocal repetition_context
logits = model(y[None], cache=cache)
logits = logits[:, -1, :]
+ if logits_processor:
+ nonlocal tokens
+ tokens = mx.concat([tokens, y]) if tokens is not None else y
+ logits = logits_processor(tokens, logits)
+
+ if logit_bias:
+ logits[:, indices] += values
+
if repetition_penalty:
logits = apply_repetition_penalty(
logits, repetition_context, repetition_penalty
diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py
new file mode 100644
index 00000000..bc969844
--- /dev/null
+++ b/llms/tests/test_generate.py
@@ -0,0 +1,55 @@
+# Copyright © 2024 Apple Inc.
+
+import unittest
+
+from mlx_lm.utils import generate, load
+
+
+class TestGenerate(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+ cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+
+ def test_generate(self):
+ # Simple test that generation runs
+ text = generate(
+ self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
+ )
+
+ def test_generate_with_logit_bias(self):
+ logit_bias = {0: 2000.0, 1: -20.0}
+ text = generate(
+ self.model,
+ self.tokenizer,
+ "hello",
+ max_tokens=5,
+ verbose=False,
+ logit_bias=logit_bias,
+ )
+ self.assertEqual(text, "!!!!!")
+
+ def test_generate_with_processor(self):
+ init_toks = self.tokenizer.encode("hello")
+
+ all_toks = None
+
+ def logits_processor(toks, logits):
+ nonlocal all_toks
+ all_toks = toks
+ return logits
+
+ generate(
+ self.model,
+ self.tokenizer,
+ "hello",
+ max_tokens=5,
+ verbose=False,
+ logits_processor=logits_processor,
+ )
+ self.assertEqual(len(all_toks), len(init_toks) + 5)
+
+
+if __name__ == "__main__":
+ unittest.main()
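Beyond capturing tokens as in the test above, a processor can rewrite the logits before sampling. The sketch below (not part of the patch) suppresses a hypothetical token id at every step, using the single-callable signature introduced here and the same in-place indexing style as `utils.py`:

```python
from mlx_lm.utils import generate, load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

BANNED_TOKEN_ID = 0  # illustrative id


def ban_token(tokens, logits):
    # Push the banned token's logit far below everything else.
    logits[:, BANNED_TOKEN_ID] = -1e9
    return logits


text = generate(
    model,
    tokenizer,
    "hello",
    max_tokens=5,
    verbose=False,
    logits_processor=ban_token,
)
```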
From 7ec2021bb99de16f89c59949439f428fc968a0d7 Mon Sep 17 00:00:00 2001
From: madroid
Date: Sun, 29 Sep 2024 01:41:36 +0800
Subject: [PATCH 036/188] LoRA: support tools(function calling) format datasets
(#995)
* LoRA: support fine-tuning tools datasets
* LoRA: Split small function
* LoRA: add tools format to lora docs
* LoRA: pre-commit fix
* Revert "LoRA: pre-commit fix"
This reverts commit b94b7e0fe7c6adfb642e1392710c027096d91d49.
* Revert "LoRA: Split small function"
This reverts commit 3f6a5f19fd8ba24bf6933c3a9bdcc66c8b29825f.
* LoRA: remove ToolsDataset
In a JSONL file, not every example is required to include the `tools` value.
* nit in readme
* nit in readme
* nit in readme
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/LORA.md | 68 +++++++++++++++++++++++++++++++----
llms/mlx_lm/tuner/datasets.py | 5 ++-
llms/mlx_lm/tuner/trainer.py | 4 +--
3 files changed, 66 insertions(+), 11 deletions(-)
diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 2d9a2553..8aec89ec 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -160,8 +160,8 @@ For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory.
-Currently, `*.jsonl` files support three data formats: `chat`,
-`completions`, and `text`. Here are three examples of these formats:
+Currently, `*.jsonl` files support `chat`, `tools`, `completions`, and `text`
+data formats. Here are examples of these formats:
`chat`:
@@ -169,6 +169,58 @@ Currently, `*.jsonl` files support three data formats: `chat`,
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assistant you today."}]}
```
+`tools`:
+
+```jsonl
+{"messages":[{"role":"user","content":"What is the weather in San Francisco?"},{"role":"assistant","tool_calls":[{"id":"call_id","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"}}]}],"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and country, eg. San Francisco, USA"},"format":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location","format"]}}}]}
+```
+
+
+View the expanded single data tool format
+
+```jsonl
+{
+ "messages": [
+ { "role": "user", "content": "What is the weather in San Francisco?" },
+ {
+ "role": "assistant",
+ "tool_calls": [
+ {
+ "id": "call_id",
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"
+ }
+ }
+ ]
+ }
+ ],
+ "tools": [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "description": "Get the current weather",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and country, eg. San Francisco, USA"
+ },
+ "format": { "type": "string", "enum": ["celsius", "fahrenheit"] }
+ },
+ "required": ["location", "format"]
+ }
+ }
+ }
+ ]
+}
+```
+
+
+
`completions`:
```jsonl
@@ -215,11 +267,13 @@ hf_dataset:
- Arguments specified in `config` will be passed as keyword arguments to
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
-In general, for the `chat` and `completions` formats, Hugging Face [chat
-templates](https://huggingface.co/blog/chat-templates) are used. This applies
-the model's chat template by default. If the model does not have a chat
-template, then Hugging Face will use a default. For example, the final text in
-the `chat` example above with Hugging Face's default template becomes:
+In general, for the `chat`, `tools` and `completions` formats, Hugging Face
+[chat
+templates](https://huggingface.co/docs/transformers/main/en/chat_templating)
+are used. This applies the model's chat template by default. If the model does
+not have a chat template, then Hugging Face will use a default. For example,
+the final text in the `chat` example above with Hugging Face's default template
+becomes:
```text
<|im_start|>system
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 3d99894c..2b8abf43 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -36,7 +36,10 @@ class ChatDataset(Dataset):
def __getitem__(self, idx: int):
messages = self._data[idx]["messages"]
text = self._tokenizer.apply_chat_template(
- messages, tokenize=False, add_generation_prompt=True
+ messages,
+ tools=self._data[idx].get("tools", None),
+ tokenize=False,
+ add_generation_prompt=True,
)
return text
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 24fcc5c6..b15801a5 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -93,9 +93,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
# Encode batch
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
for b in batch:
- if b[-1] == tokenizer.eos_token_id:
- print("[WARNING] Example already has an EOS token appended")
- else:
+ if b[-1] != tokenizer.eos_token_id:
b.append(tokenizer.eos_token_id)
lengths = [len(x) for x in batch]
From 50e5ca81a8c06f4c49cec48795330209a885d2c9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?=
<60228478+Goekdeniz-Guelmez@users.noreply.github.com>
Date: Mon, 30 Sep 2024 02:12:47 +0200
Subject: [PATCH 037/188] Adding full finetuning (#903)
* Adding full model weights finetuning
* Updating the LORA.md and ACKNOWLEDGMENTS.md files.
* removing --use-dora and --fulll-training and adding --fine-tune-type
* some clean up
* reformatting and fixing dora training
* updated CONFIG_DEFAULTS
* update config example
* update in the config example file
* Update LORA.md
* merge and commit
* adding argument for dora linear layer
* clean up
* clean up in the example yaml file
* fix
* final fix before sending
* small addition to re md file
* fix for loading the fully trained model by saving all the files and configs correctly
* clean up
* removing the unnecessary files
* changing lora layers back to 16
* removed max file size
* nits
* resolve merge
* some consistency changes
---------
Co-authored-by: Awni Hannun
---
ACKNOWLEDGMENTS.md | 2 +-
llms/README.md | 2 +-
llms/mlx_lm/LORA.md | 16 ++++++----
llms/mlx_lm/examples/lora_config.yaml | 7 ++--
llms/mlx_lm/fuse.py | 15 ++++-----
llms/mlx_lm/lora.py | 46 +++++++++++++++++----------
llms/mlx_lm/tuner/trainer.py | 22 ++++++-------
llms/mlx_lm/tuner/utils.py | 35 ++++++++++----------
llms/mlx_lm/utils.py | 4 +--
9 files changed, 79 insertions(+), 70 deletions(-)
diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 2037a076..41557c29 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
-- Gökdeniz Gülmez: Added support for `MiniCPM` and `Mamba`.
+- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.
\ No newline at end of file
diff --git a/llms/README.md b/llms/README.md
index b8e1914d..75677865 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -16,7 +16,7 @@ conda install -c conda-forge mlx-lm
The `mlx-lm` package also has:
-- [LoRA and QLoRA fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
+- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 8aec89ec..80c25b4b 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -57,6 +57,9 @@ mlx_lm.lora \
--iters 600
```
+To fine-tune the full model weights, add the `--fine-tune-type full` flag.
+Currently supported fine-tuning types are `lora` (default), `dora`, and `full`.
+
The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
when using `--train` and a path to a `test.jsonl` when using `--test`. For more
details on the data format see the section on [Data](#Data).
@@ -67,8 +70,8 @@ mistralai/Mistral-7B-v0.1`.
If `--model` points to a quantized model, then the training will use QLoRA,
otherwise it will use regular LoRA.
-By default, the adapter config and weights are saved in `adapters/`. You can
-specify the output location with `--adapter-path`.
+By default, the adapter config and learned weights are saved in `adapters/`.
+You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file `.
@@ -118,7 +121,7 @@ mlx_lm.fuse --model
```
This will by default load the adapters from `adapters/`, and save the fused
-model in the path `lora_fused_model/`. All of these are configurable.
+model in the path `fused_model/`. All of these are configurable.
To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
@@ -141,7 +144,7 @@ mlx_lm.fuse \
--export-gguf
```
-This will save the GGUF model in `lora_fused_model/ggml-model-f16.gguf`. You
+This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You
can specify the file name with `--gguf-path`.
## Data
@@ -301,7 +304,7 @@ of memory. Here are some tips to reduce memory use should you need to do so:
setting this to `2` or `1` will reduce memory consumption. This may slow
things down a little, but will also reduce the memory use.
-3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
+3. Reduce the number of layers to fine-tune with `--num-layers`. The default
is `16`, so you can try `8` or `4`. This reduces the amount of memory
needed for back propagation. It may also reduce the quality of the
fine-tuned model if you are fine-tuning with a lot of data.
@@ -323,7 +326,7 @@ mlx_lm.lora \
--model mistralai/Mistral-7B-v0.1 \
--train \
--batch-size 1 \
- --lora-layers 4 \
+ --num-layers 4 \
--data wikisql
```
@@ -333,4 +336,5 @@ tokens-per-second, using the MLX Example
data set.
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
+
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
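Putting the new flag together with the options documented above, a full-weight fine-tuning run looks roughly like the following sketch (the model and data path are placeholders):

```shell
mlx_lm.lora \
    --model mistralai/Mistral-7B-v0.1 \
    --train \
    --fine-tune-type full \
    --num-layers 8 \
    --data /path/to/training/data \
    --batch-size 1 \
    --iters 600
```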
diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml
index 073a5b6f..4ec9a23c 100644
--- a/llms/mlx_lm/examples/lora_config.yaml
+++ b/llms/mlx_lm/examples/lora_config.yaml
@@ -1,8 +1,12 @@
# The path to the local model directory or Hugging Face repo.
model: "mlx_model"
+
# Whether or not to train (boolean)
train: true
+# The fine-tuning method: "lora", "dora", or "full".
+fine_tune_type: lora
+
# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"
@@ -51,9 +55,6 @@ max_seq_length: 2048
# Use gradient checkpointing to reduce memory use.
grad_checkpoint: false
-# Use DoRA instead of LoRA.
-use_dora: false
-
# LoRA parameters can only be specified in a config file
lora_parameters:
# The layer keys to apply LoRA to.
diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py
index 16457036..b0c46a74 100644
--- a/llms/mlx_lm/fuse.py
+++ b/llms/mlx_lm/fuse.py
@@ -8,7 +8,7 @@ from mlx.utils import tree_flatten, tree_unflatten
from .gguf import convert_to_gguf
from .tuner.dora import DoRAEmbedding, DoRALinear
from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
-from .tuner.utils import apply_lora_layers, dequantize
+from .tuner.utils import dequantize, load_adapters
from .utils import (
fetch_from_hub,
get_model_path,
@@ -29,7 +29,7 @@ def parse_arguments() -> argparse.Namespace:
)
parser.add_argument(
"--save-path",
- default="lora_fused_model",
+ default="fused_model",
help="The path to save the fused model.",
)
parser.add_argument(
@@ -77,17 +77,14 @@ def main() -> None:
model, config, tokenizer = fetch_from_hub(model_path)
model.freeze()
- model = apply_lora_layers(model, args.adapter_path)
+ model = load_adapters(model, args.adapter_path)
fused_linears = [
- (n, m.fuse())
- for n, m in model.named_modules()
- if isinstance(
- m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding)
- )
+ (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
]
- model.update_modules(tree_unflatten(fused_linears))
+ if fused_linears:
+ model.update_modules(tree_unflatten(fused_linears))
if args.de_quantize:
print("De-quantizing model")
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 580e3d3c..69232774 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -15,9 +15,9 @@ from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
- apply_lora_layers,
build_schedule,
linear_to_lora_layers,
+ load_adapters,
print_trainable_parameters,
)
from .utils import load, save_config
@@ -41,9 +41,10 @@ yaml_loader.add_implicit_resolver(
CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
+ "fine_tune_type": "lora",
"data": "data/",
"seed": 0,
- "lora_layers": 16,
+ "num_layers": 16,
"batch_size": 4,
"iters": 1000,
"val_batches": 25,
@@ -58,7 +59,6 @@ CONFIG_DEFAULTS = {
"max_seq_length": 2048,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
- "use_dora": False,
}
@@ -82,7 +82,14 @@ def build_parser():
help="Directory with {train, valid, test}.jsonl files",
)
parser.add_argument(
- "--lora-layers",
+ "--fine-tune-type",
+ type=str,
+ choices=["lora", "dora", "full"],
+ default="lora",
+ help="Type of fine-tuning to perform: lora, dora, or full.",
+ )
+ parser.add_argument(
+ "--num-layers",
type=int,
help="Number of layers to fine-tune. Default is 16, use -1 for all.",
)
@@ -107,12 +114,12 @@ def build_parser():
parser.add_argument(
"--resume-adapter-file",
type=str,
- help="Load path to resume training with the given adapters.",
+ help="Load path to resume training from the given fine-tuned weights.",
)
parser.add_argument(
"--adapter-path",
type=str,
- help="Save/load path for the adapters.",
+ help="Save/load path for the fine-tuned weights.",
)
parser.add_argument(
"--save-every",
@@ -148,9 +155,6 @@ def build_parser():
default=None,
)
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
- parser.add_argument(
- "--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
- )
return parser
@@ -162,21 +166,31 @@ def train_model(
valid_set,
training_callback: TrainingCallback = None,
):
- # Freeze all layers
model.freeze()
+ if args.fine_tune_type == "full":
+ for l in model.layers[-min(args.num_layers, 0) :]:
+ l.unfreeze()
+ elif args.fine_tune_type in ["lora", "dora"]:
+ # Convert linear layers to lora/dora layers and unfreeze in the process
+ linear_to_lora_layers(
+ model,
+ args.num_layers,
+ args.lora_parameters,
+ use_dora=(args.fine_tune_type == "dora"),
+ )
+ else:
+ raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")
- # Convert linear layers to lora layers and unfreeze in the process
- linear_to_lora_layers(model, args.lora_layers, args.lora_parameters, args.use_dora)
-
- # Resume training the given adapters.
+ # Resume from weights if provided
if args.resume_adapter_file is not None:
- print(f"Loading pretrained adapters from {args.resume_adapter_file}")
+ print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
model.load_weights(args.resume_adapter_file, strict=False)
print_trainable_parameters(model)
adapter_path = Path(args.adapter_path)
adapter_path.mkdir(parents=True, exist_ok=True)
+
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")
@@ -240,7 +254,7 @@ def run(args, training_callback: TrainingCallback = None):
if args.test and not args.train:
# Allow testing without LoRA layers by providing empty path
if args.adapter_path != "":
- apply_lora_layers(model, args.adapter_path)
+ load_adapters(model, args.adapter_path)
elif args.train:
print("Training")
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index b15801a5..1d934a72 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -1,5 +1,7 @@
# Copyright © 2024 Apple Inc.
+import glob
+import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
@@ -285,24 +287,18 @@ def train(
# Save adapter weights
if it % args.steps_per_save == 0:
- save_adapter(model, args.adapter_file)
+ adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+ mx.save_safetensors(str(args.adapter_file), adapter_weights)
checkpoint = (
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
)
- save_adapter(model, checkpoint)
+ mx.save_safetensors(str(checkpoint), adapter_weights)
print(
f"Iter {it}: Saved adapter weights to "
f"{args.adapter_file} and {checkpoint}."
)
- # save final adapter weights
- save_adapter(model, args.adapter_file)
- print(f"Saved final adapter weights to {args.adapter_file}.")
-
-
-def save_adapter(
- model: nn.Module,
- adapter_file: Union[str, Path],
-):
- flattened_tree = tree_flatten(model.trainable_parameters())
- mx.save_safetensors(str(adapter_file), dict(flattened_tree))
+ # Save final weights
+ adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+ mx.save_safetensors(str(args.adapter_file), adapter_weights)
+ print(f"Saved final weights to {args.adapter_file}.")
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index ab9d37aa..7c78ee91 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -36,7 +36,7 @@ def build_schedule(schedule_config: Dict):
def linear_to_lora_layers(
model: nn.Module,
- num_lora_layers: int,
+ num_layers: int,
config: Dict,
use_dora: bool = False,
):
@@ -45,22 +45,17 @@ def linear_to_lora_layers(
Args:
model (nn.Module): The neural network model.
- num_lora_layers (int): The number of blocks to convert to lora layers
+ num_layers (int): The number of blocks to convert to lora layers
starting from the last layer.
config (dict): More configuration parameters for LoRA, including the
rank, scale, and optional layer keys.
use_dora (bool): If True, uses DoRA instead of LoRA.
Default: ``False``
"""
- num_layers = len(model.layers)
-
- if num_lora_layers < 0:
- num_lora_layers = num_layers
-
- if num_lora_layers > num_layers:
+ if num_layers > len(model.layers):
raise ValueError(
- f"Requested {num_lora_layers} LoRA layers "
- f"but the model only has {num_layers} layers."
+ f"Requested {num_layers} LoRA layers "
+ f"but the model only has {len(model.layers)} layers."
)
def to_lora(layer):
@@ -151,7 +146,7 @@ def linear_to_lora_layers(
else:
raise ValueError(f"Lora does not support {model.model_type}")
- for l in model.layers[num_layers - num_lora_layers :]:
+ for l in model.layers[-min(num_layers, 0) :]:
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
if lora_layers:
l.update_modules(tree_unflatten(lora_layers))
@@ -161,9 +156,9 @@ def linear_to_lora_layers(
model.update_modules(tree_unflatten(lora_modules))
-def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
+def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
"""
- Apply LoRA layers to the model.
+ Load any fine-tuned adapters / layers.
Args:
model (nn.Module): The neural network model.
@@ -177,12 +172,14 @@ def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
with open(adapter_path / "adapter_config.json", "r") as fid:
config = types.SimpleNamespace(**json.load(fid))
- linear_to_lora_layers(
- model,
- config.lora_layers,
- config.lora_parameters,
- getattr(config, "use_dora", False),
- )
+ fine_tune_type = getattr(config, "fine_tune_type", "lora")
+ if fine_tune_type != "full":
+ linear_to_lora_layers(
+ model,
+ config.num_layers,
+ config.lora_parameters,
+ use_dora=(fine_tune_type == "dora"),
+ )
model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
return model
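For context, `load_adapters` expects the adapter directory to contain `adapters.safetensors` plus an `adapter_config.json` along these lines; this is a sketch of the relevant keys only, since `save_config` writes out all of the training arguments:

```json
{
  "fine_tune_type": "lora",
  "num_layers": 16,
  "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
}
```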
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 16271c3e..9411138d 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -21,8 +21,8 @@ from transformers import PreTrainedTokenizer
from .models.base import KVCache, RotatingKVCache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
-from .tuner.utils import apply_lora_layers
from .tuner.utils import dequantize as dequantize_model
+from .tuner.utils import load_adapters
# Constants
MODEL_REMAPPING = {
@@ -515,7 +515,7 @@ def load(
model = load_model(model_path, lazy, model_config)
if adapter_path is not None:
- model = apply_lora_layers(model, adapter_path)
+ model = load_adapters(model, adapter_path)
model.eval()
tokenizer = load_tokenizer(model_path, tokenizer_config)
From aa1c8abdc67be20a1efcacb0e3102ac732a87967 Mon Sep 17 00:00:00 2001
From: madroid
Date: Mon, 30 Sep 2024 22:36:21 +0800
Subject: [PATCH 038/188] LoRA: Support HuggingFace dataset via data parameter
(#996)
* LoRA: support huggingface dataset via `data` argument
* LoRA: Extract the load_custom_hf_dataset function
* LoRA: split small functions
* fix spelling errors
* handle load hf dataset error
* fix pre-commit lint
* update data argument help
* nits and doc
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/LORA.md | 8 ++-
llms/mlx_lm/lora.py | 5 +-
llms/mlx_lm/tuner/datasets.py | 131 +++++++++++++++++++++-------------
3 files changed, 93 insertions(+), 51 deletions(-)
diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 80c25b4b..2d0dcf60 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -251,7 +251,13 @@ To use Hugging Face datasets, first install the `datasets` package:
pip install datasets
```
-Specify the Hugging Face dataset arguments in a YAML config. For example:
+If the Hugging Face dataset is already in a supported format, you can specify
+it on the command line. For example, pass `--data mlx-community/wikisql` to
+train on the pre-formatted WikiSQL data.
+
+Otherwise, provide a mapping of keys in the dataset to the features MLX LM
+expects. Use a YAML config to specify the Hugging Face dataset arguments. For
+example:
```
hf_dataset:
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 69232774..c96e75a7 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -79,7 +79,10 @@ def build_parser():
parser.add_argument(
"--data",
type=str,
- help="Directory with {train, valid, test}.jsonl files",
+ help=(
+ "Directory with {train, valid, test}.jsonl files or the name "
+ "of a Hugging Face dataset (e.g., 'mlx-community/wikisql')"
+ ),
)
parser.add_argument(
"--fine-tune-type",
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 2b8abf43..20b32eff 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -76,17 +76,14 @@ class CompletionsDataset(Dataset):
return text
-def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
- # Return empty dataset for non-existent paths
- if not path.exists():
- return []
- with open(path, "r") as fid:
- data = [json.loads(l) for l in fid]
- if "messages" in data[0]:
+def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
+ sample = data[0]
+
+ if "messages" in sample:
return ChatDataset(data, tokenizer)
- elif "prompt" in data[0] and "completion" in data[0]:
+ elif "prompt" in sample and "completion" in sample:
return CompletionsDataset(data, tokenizer)
- elif "text" in data[0]:
+ elif "text" in sample:
return Dataset(data)
else:
raise ValueError(
@@ -95,54 +92,90 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
)
-def load_dataset(args, tokenizer: PreTrainedTokenizer):
- if getattr(args, "hf_dataset", None) is not None:
- import datasets
+def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer):
+ def load_subset(path):
+ if not path.exists():
+ return []
+ with open(path, "r") as fid:
+ data = [json.loads(l) for l in fid]
+ return create_dataset(data, tokenizer)
- hf_args = args.hf_dataset
- dataset_name = hf_args["name"]
- print(f"Loading Hugging Face dataset {dataset_name}.")
- text_feature = hf_args.get("text_feature")
- prompt_feature = hf_args.get("prompt_feature")
- completion_feature = hf_args.get("completion_feature")
+ names = ("train", "valid", "test")
+ train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
+ return train, valid, test
- def create_hf_dataset(split: str = None):
- ds = datasets.load_dataset(
- dataset_name,
- split=split,
- **hf_args.get("config", {}),
- )
- if prompt_feature and completion_feature:
- return CompletionsDataset(
- ds, tokenizer, prompt_feature, completion_feature
- )
- elif text_feature:
- return Dataset(train_ds, text_key=text_feature)
- else:
- raise ValueError(
- "Specify either a prompt and completion feature or a text "
- "feature for the Hugging Face dataset."
- )
- if args.train:
- train_split = hf_args.get("train_split", "train[:80%]")
- valid_split = hf_args.get("valid_split", "train[-10%:]")
- train = create_hf_dataset(split=train_split)
- valid = create_hf_dataset(split=valid_split)
- else:
- train, valid = [], []
- if args.test:
- test = create_hf_dataset(split=hf_args.get("test_split"))
- else:
- test = []
+def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
+ from datasets import exceptions, load_dataset
+
+ try:
+ dataset = load_dataset(data_id)
- else:
names = ("train", "valid", "test")
- data_path = Path(args.data)
train, valid, test = [
- create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
+ create_dataset(dataset[n], tokenizer) if n in dataset.keys() else []
+ for n in names
]
+
+ except exceptions.DatasetNotFoundError:
+ raise ValueError(f"Not found Hugging Face dataset: {data_id} .")
+
+ return train, valid, test
+
+
+def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
+ import datasets
+
+ hf_args = args.hf_dataset
+ dataset_name = hf_args["name"]
+ print(f"Loading Hugging Face dataset {dataset_name}.")
+ text_feature = hf_args.get("text_feature")
+ prompt_feature = hf_args.get("prompt_feature")
+ completion_feature = hf_args.get("completion_feature")
+
+ def create_hf_dataset(split: str = None):
+ ds = datasets.load_dataset(
+ dataset_name,
+ split=split,
+ **hf_args.get("config", {}),
+ )
+ if prompt_feature and completion_feature:
+ return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
+ elif text_feature:
+ return Dataset(train_ds, text_key=text_feature)
+ else:
+ raise ValueError(
+ "Specify either a prompt and completion feature or a text "
+ "feature for the Hugging Face dataset."
+ )
+
+ if args.train:
+ train_split = hf_args.get("train_split", "train[:80%]")
+ valid_split = hf_args.get("valid_split", "train[-10%:]")
+ train = create_hf_dataset(split=train_split)
+ valid = create_hf_dataset(split=valid_split)
+ else:
+ train, valid = [], []
+ if args.test:
+ test = create_hf_dataset(split=hf_args.get("test_split"))
+ else:
+ test = []
+
+ return train, valid, test
+
+
+def load_dataset(args, tokenizer: PreTrainedTokenizer):
+ if getattr(args, "hf_dataset", None) is not None:
+ train, valid, test = load_custom_hf_dataset(args, tokenizer)
+ else:
+ data_path = Path(args.data)
+ if data_path.exists():
+ train, valid, test = load_local_dataset(data_path, tokenizer)
+ else:
+ print(f"Loading Hugging Face dataset {args.data}.")
+ train, valid, test = load_hf_dataset(args.data, tokenizer)
+
if args.train and len(train) == 0:
raise ValueError(
"Training set not found or empty. Must provide training set for fine-tuning."
From 418d9a5511fbefced1229cfed1d311a771aa5db5 Mon Sep 17 00:00:00 2001
From: Zai Thottakath
Date: Mon, 30 Sep 2024 10:01:11 -0500
Subject: [PATCH 039/188] Feature: QDoRA (#891)
* feat: QDoRA with tests and a small bug fix for recalculation of self.m
* some simplifications and fixes
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/tuner/dora.py | 51 ++++++++++---
llms/tests/test_finetune.py | 143 +++++++++++++++++++++++++++++++++++-
2 files changed, 183 insertions(+), 11 deletions(-)
diff --git a/llms/mlx_lm/tuner/dora.py b/llms/mlx_lm/tuner/dora.py
index bd2dfb01..aba1f6f4 100644
--- a/llms/mlx_lm/tuner/dora.py
+++ b/llms/mlx_lm/tuner/dora.py
@@ -14,10 +14,11 @@ class DoRALinear(nn.Module):
dropout: float = 0.0,
scale: float = 20.0,
):
- # TODO support quantized weights in DoRALinear
+ # TODO remove when input_dims and output_dims are attributes
+ # on linear and quantized linear
output_dims, input_dims = linear.weight.shape
if isinstance(linear, nn.QuantizedLinear):
- raise ValueError("DoRALinear does not yet support quantization.")
+ input_dims *= 32 // linear.bits
dora_lin = DoRALinear(
input_dims=input_dims,
output_dims=output_dims,
@@ -31,13 +32,13 @@ class DoRALinear(nn.Module):
def fuse(self, de_quantize: bool = False):
linear = self.linear
bias = "bias" in linear
- weight = linear.weight
+ weight = self._dequantized_weight()
- # Use the same type as the linear weight if not quantized
+ # Use the same type as the linear weight
dtype = weight.dtype
output_dims, input_dims = weight.shape
- fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
+ fused_linear = nn.Linear(input_dims, output_dims, bias=False)
lora_b = (self.scale * self.lora_b.T).astype(dtype)
lora_a = self.lora_a.T.astype(dtype)
@@ -47,6 +48,13 @@ class DoRALinear(nn.Module):
if bias:
fused_linear.bias = linear.bias
+
+ if self._is_quantized() and not de_quantize:
+ fused_linear = nn.QuantizedLinear.from_linear(
+ fused_linear,
+ linear.group_size,
+ linear.bits,
+ )
return fused_linear
def __init__(
@@ -76,22 +84,45 @@ class DoRALinear(nn.Module):
)
self.lora_b = mx.zeros(shape=(r, output_dims))
- def set_linear(self, linear: nn.Linear):
+ def set_linear(self, linear):
+ """
+ Set the self.linear layer and recompute self.m.
+ """
self.linear = linear
- self.m = mx.linalg.norm(self.linear.weight, axis=1)
+ self.m = mx.linalg.norm(self._dequantized_weight().astype(mx.float32), axis=1)
+
+ def _dequantized_weight(self):
+ """
+        Return the weight of the linear layer, dequantizing it if it is quantized.
+ """
+ weight = self.linear.weight
+ if self._is_quantized():
+ weight = mx.dequantize(
+ weight,
+ self.linear.scales,
+ self.linear.biases,
+ self.linear.group_size,
+ self.linear.bits,
+ )
+ return weight
+
+ def _is_quantized(self):
+ return isinstance(self.linear, nn.QuantizedLinear)
def __call__(self, x):
# Regular LoRA (without a bias)
- y = x @ self.linear.weight.T
+ w = self._dequantized_weight()
+ y = x @ w.T
+
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
out = y + (self.scale * z).astype(x.dtype)
# Compute the norm of the adapted weights
- adapted = self.linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T
+ adapted = w + (self.scale * self.lora_b.T) @ self.lora_a.T
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
# Remove the norm and scale by the learned magnitude
- out = (self.m / denom) * out
+ out = (self.m / denom).astype(x.dtype) * out
if "bias" in self.linear:
out = out + self.linear.bias
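For reference, the forward pass implemented in `__call__` above is equivalent to the sketch below, written with plain mlx ops (dropout omitted):

```python
import mlx.core as mx


def dora_forward(x, w, lora_a, lora_b, m, scale, bias=None):
    # Base projection plus the scaled low-rank update (w is the dequantized weight).
    y = x @ w.T + (scale * (x @ lora_a) @ lora_b).astype(x.dtype)
    # Per-output-row norm of the adapted weight, treated as a constant.
    adapted = w + (scale * lora_b.T) @ lora_a.T
    denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
    # Rescale by the learned magnitude vector m.
    out = (m / denom).astype(x.dtype) * y
    return out if bias is None else out + bias
```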
diff --git a/llms/tests/test_finetune.py b/llms/tests/test_finetune.py
index 289b8cfb..107be092 100644
--- a/llms/tests/test_finetune.py
+++ b/llms/tests/test_finetune.py
@@ -11,7 +11,7 @@ import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_flatten
from mlx_lm import lora, tuner
-from mlx_lm.tuner.dora import DoRAEmbedding
+from mlx_lm.tuner.dora import DoRAEmbedding, DoRALinear
from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
from mlx_lm.tuner.trainer import evaluate
from mlx_lm.tuner.utils import build_schedule
@@ -164,6 +164,147 @@ class TestDora(unittest.TestCase):
self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))
+ def test_llama(self):
+ from mlx_lm.models import llama
+
+ hidden_size = 1024
+ intermediate_size = 2048
+ args = llama.ModelArgs(
+ model_type="llama",
+ hidden_size=hidden_size,
+ num_hidden_layers=4,
+ intermediate_size=intermediate_size,
+ num_attention_heads=4,
+ rms_norm_eps=1e-5,
+ vocab_size=10_000,
+ )
+
+ dora_layers = 4
+
+ def check_config(params):
+ n_keys = 2
+ if "keys" in params:
+ n_keys = len(params["keys"])
+ model = llama.Model(args)
+ model.freeze()
+ tuner.utils.linear_to_lora_layers(model, dora_layers, params, use_dora=True)
+ trainable_params = sum(
+ v.size for _, v in tree_flatten(model.trainable_parameters())
+ )
+ self.assertEqual(
+ trainable_params,
+ dora_layers
+ * (params["rank"] * hidden_size * 2 * n_keys + n_keys * hidden_size),
+ )
+
+ params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
+ check_config(params)
+
+ params["rank"] = 1
+ check_config(params)
+
+ params["keys"] = ["self_attn.k_proj"]
+ check_config(params)
+
+ def test_dora_m_parameter(self):
+ dora_lin = DoRALinear(input_dims=100, output_dims=100)
+ self.assertTrue(
+ mx.allclose(dora_lin.m, mx.linalg.norm(dora_lin.linear.weight, axis=1))
+ )
+
+ # Recomputes m when changing Linear
+ inital_m = dora_lin.m
+ lin = nn.Linear(10, 10)
+ dora_lin.set_linear(lin)
+ self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(lin.weight, axis=1)))
+
+ # Works with quantized weights
+ quantized_linear = nn.QuantizedLinear(512, 512)
+ dora_lin.set_linear(quantized_linear)
+ dequantized_weight = mx.dequantize(
+ quantized_linear.weight,
+ quantized_linear.scales,
+ quantized_linear.biases,
+ quantized_linear.group_size,
+ quantized_linear.bits,
+ )
+ self.assertTrue(
+ mx.allclose(dora_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
+ )
+
+ def test_dora_from_linear(self):
+ in_dims = 256
+ out_dims = 256
+ r = 4
+
+ linear = nn.Linear(in_dims, out_dims)
+ dora_lin = DoRALinear.from_base(linear, r)
+ self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(linear.weight, axis=1)))
+ self.assertEqual(dora_lin.lora_a.shape, (in_dims, r))
+ self.assertEqual(dora_lin.lora_b.shape, (r, out_dims))
+ self.assertEqual(dora_lin.m.shape, (out_dims,))
+
+ quantized_linear = nn.QuantizedLinear(in_dims, out_dims)
+ dequantized_weight = mx.dequantize(
+ quantized_linear.weight,
+ quantized_linear.scales,
+ quantized_linear.biases,
+ quantized_linear.group_size,
+ quantized_linear.bits,
+ )
+ dora_quant_lin = DoRALinear.from_base(quantized_linear, r)
+ self.assertTrue(
+ mx.allclose(dora_quant_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
+ )
+ self.assertEqual(dora_quant_lin.lora_a.shape, (in_dims, r))
+ self.assertEqual(dora_quant_lin.lora_b.shape, (r, out_dims))
+ self.assertEqual(dora_quant_lin.m.shape, (out_dims,))
+
+ def test_dora_to_linear(self):
+ in_dims = 256
+ out_dims = 256
+ r = 4
+
+ linear = nn.Linear(in_dims, out_dims, bias=True)
+ dora_lin = DoRALinear.from_base(linear, r)
+ to_linear = dora_lin.fuse()
+ self.assertTrue(mx.allclose(linear.weight, to_linear.weight))
+ self.assertTrue(mx.allclose(linear.bias, to_linear.bias))
+
+ def dequantize_weight(quantized_linear):
+ return mx.dequantize(
+ quantized_linear.weight,
+ quantized_linear.scales,
+ quantized_linear.biases,
+ quantized_linear.group_size,
+ quantized_linear.bits,
+ )
+
+ quantized_linear = nn.QuantizedLinear(in_dims, out_dims, bias=True)
+ dora_quantized_linear = DoRALinear.from_base(quantized_linear, r)
+ # Dequantize
+ to_linear_from_quantized = dora_quantized_linear.fuse(de_quantize=True)
+ self.assertTrue(
+ mx.allclose(quantized_linear.bias, to_linear_from_quantized.bias)
+ )
+ self.assertTrue(
+ mx.allclose(
+ dequantize_weight(quantized_linear), to_linear_from_quantized.weight
+ )
+ )
+
+ def test_dora_dtype(self):
+ in_dims = 256
+ out_dims = 256
+ r = 4
+
+ linear = nn.Linear(in_dims, out_dims, bias=True)
+ linear.set_dtype(mx.float16)
+ dora_lin = DoRALinear.from_base(linear, r)
+
+ x = mx.random.uniform(shape=(2, 256)).astype(mx.float16)
+ self.assertEqual(dora_lin(x).dtype, mx.float16)
+
class TestScheduleConfig(unittest.TestCase):
def test_join(self):
From 0866e23a67c3d5ab8dbac352b112f13967103942 Mon Sep 17 00:00:00 2001
From: nathan <97126670+nathanrchn@users.noreply.github.com>
Date: Mon, 30 Sep 2024 17:49:03 +0200
Subject: [PATCH 040/188] repetition_penalty and logits_bias just using
logits_processors (#1004)
* refactor of repetition_penalty and logits_bias to use logits_processor
* nits
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/utils.py | 66 ++++++++++++++++++-------------------
llms/tests/test_generate.py | 2 +-
2 files changed, 33 insertions(+), 35 deletions(-)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 9411138d..54a96457 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -101,7 +101,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
return model_path
-def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
+def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
"""
Apply repetition penalty to specific logits based on the given context.
@@ -109,19 +109,18 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
Args:
logits (mx.array): The logits produced by the language model.
- generated_tokens (any): A list of N previous tokens.
+ tokens (mx.array): A list of N previous tokens.
penalty (float): The repetition penalty factor to be applied.
Returns:
logits (mx.array): Logits with repetition penalty applied to generated tokens.
"""
- if len(generated_tokens) > 0:
- indices = mx.array([token for token in generated_tokens])
- selected_logits = logits[:, indices]
+ if len(tokens) > 0:
+ selected_logits = logits[:, tokens]
selected_logits = mx.where(
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
)
- logits[:, indices] = selected_logits
+ logits[:, tokens] = selected_logits
return logits
@@ -158,7 +157,7 @@ def generate_step(
max_kv_size: Optional[int] = None,
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
logit_bias: Optional[Dict[int, float]] = None,
- logits_processor: Optional[Callable[[mx.array, mx.array], mx.array]] = None,
+ logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -182,8 +181,8 @@ def generate_step(
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
logit_bias (dictionary, optional): Additive logit bias.
- logits_processor (Callable[[mx.array, mx.array], mx.array], optional):
- A function that takes tokens and logits and returns the processed
+ logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
+ A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
Yields:
@@ -213,6 +212,27 @@ def generate_step(
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
)
+ logits_processor = logits_processor or []
+
+ if repetition_penalty:
+
+ def repetition_penalty_processor(tokens, logits):
+ return apply_repetition_penalty(
+ logits, tokens[-repetition_context_size:], repetition_penalty
+ )
+
+ logits_processor.append(repetition_penalty_processor)
+
+ if logit_bias:
+ indices = mx.array(list(logit_bias.keys()))
+ values = mx.array(list(logit_bias.values()))
+
+ def logit_bias_processor(_, logits):
+ logits[:, indices] += values
+ return logits
+
+ logits_processor.append(logit_bias_processor)
+
y = prompt
tokens = None
@@ -229,40 +249,18 @@ def generate_step(
c.update_and_fetch(h[0], h[1])
mx.eval([c.state for c in cache])
- repetition_context = prompt.tolist()
-
- if repetition_context_size:
- repetition_context = repetition_context[-repetition_context_size:]
-
- if logit_bias:
- indices = mx.array(list(logit_bias.keys()))
- values = mx.array(list(logit_bias.values()))
-
def _step(y):
- nonlocal repetition_context
logits = model(y[None], cache=cache)
logits = logits[:, -1, :]
if logits_processor:
nonlocal tokens
tokens = mx.concat([tokens, y]) if tokens is not None else y
- logits = logits_processor(tokens, logits)
- if logit_bias:
- logits[:, indices] += values
+ for processor in logits_processor:
+ logits = processor(tokens, logits)
- if repetition_penalty:
- logits = apply_repetition_penalty(
- logits, repetition_context, repetition_penalty
- )
- y, logprobs = sample(logits)
- repetition_context.append(y.item())
- else:
- y, logprobs = sample(logits)
-
- if repetition_context_size:
- if len(repetition_context) > repetition_context_size:
- repetition_context = repetition_context[-repetition_context_size:]
+ y, logprobs = sample(logits)
return y, logprobs.squeeze(0)
while y.size > prefill_step_size:
diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py
index bc969844..68f1670b 100644
--- a/llms/tests/test_generate.py
+++ b/llms/tests/test_generate.py
@@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase):
"hello",
max_tokens=5,
verbose=False,
- logits_processor=logits_processor,
+ logits_processor=[logits_processor],
)
self.assertEqual(len(all_toks), len(init_toks) + 5)
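For context, here is a minimal sketch of the list-based `logits_processor` API this patch introduces, following the pattern in the `test_generate.py` change above. The model repo id and the banned token id are illustrative assumptions, not part of the patch.

```python
# Sketch only: assumes mlx-lm with this patch applied; the model id and the
# banned token id (0) are placeholders for illustration.
import mlx.core as mx
from mlx_lm import generate, load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

def ban_token(token_id: int):
    # Each processor takes (tokens_so_far, logits) and returns new logits.
    def processor(tokens: mx.array, logits: mx.array) -> mx.array:
        logits[:, token_id] = -1e9  # push the banned token's logit far down
        return logits
    return processor

text = generate(
    model,
    tokenizer,
    prompt="hello",
    max_tokens=5,
    verbose=False,
    logits_processor=[ban_token(0)],  # now a list; processors run in order
)
print(text)
```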
From 36c1d8e8dcd42a7104ab532ef4ae003c08f41c5f Mon Sep 17 00:00:00 2001
From: madroid
Date: Thu, 3 Oct 2024 03:36:07 +0800
Subject: [PATCH 041/188] Server: support function calling (#1003)
---
llms/mlx_lm/server.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index f2d8b86a..42962b54 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -594,6 +594,7 @@ class APIHandler(BaseHTTPRequestHandler):
):
prompt = self.tokenizer.apply_chat_template(
body["messages"],
+ body.get("tools", None),
tokenize=True,
add_generation_prompt=True,
)
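As a rough illustration of what this one-line change enables, the sketch below sends an OpenAI-style chat request with a `tools` field to `mlx_lm.server`, which now forwards the tools to `tokenizer.apply_chat_template`. The host, port, and tool schema are assumptions for the example, not taken from the patch.

```python
# Sketch only: start the server first (e.g. `mlx_lm.server --model <model>`).
# The host/port and the tool definition below are illustrative assumptions.
import json
from urllib import request

body = {
    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    "max_tokens": 100,
}

req = request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(body).encode(),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    print(json.loads(resp.read()))
```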
From 9bc53fc2100319d59179a179efe34346372772cf Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 2 Oct 2024 13:13:33 -0700
Subject: [PATCH 042/188] convert (#1006)
---
whisper/convert.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/whisper/convert.py b/whisper/convert.py
index da7195e0..cdd50bc5 100644
--- a/whisper/convert.py
+++ b/whisper/convert.py
@@ -35,6 +35,8 @@ _MODELS = {
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
+ "large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
+ "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
@@ -52,6 +54,8 @@ _ALIGNMENT_HEADS = {
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
+ "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
+ "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}
From fca087be4906332ccc086013e96da93981e57037 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 7 Oct 2024 20:45:51 -0700
Subject: [PATCH 043/188] More cache improvements (#1015)
* fix rotating kv cache for chat use case
* reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat
* nit in chat
* fix tests
* fix tests
* fix tests
* docs
* chat command
* comments + docs
* Define meta_state on all Cache implementations
* fixes + trim_prompt_cache api
* fix default model
---------
Co-authored-by: Angelos Katharopoulos
---
.gitignore | 3 +
llms/README.md | 39 ++-
llms/mlx_lm/_version.py | 2 +-
llms/mlx_lm/cache_prompt.py | 19 +-
llms/mlx_lm/chat.py | 82 ++++++
llms/mlx_lm/examples/chat.py | 53 ++++
llms/mlx_lm/examples/generate_response.py | 2 +
llms/mlx_lm/generate.py | 67 ++---
llms/mlx_lm/models/base.py | 155 +---------
llms/mlx_lm/models/cache.py | 333 ++++++++++++++++++++++
llms/mlx_lm/models/cohere.py | 14 +-
llms/mlx_lm/models/dbrx.py | 16 +-
llms/mlx_lm/models/deepseek.py | 24 +-
llms/mlx_lm/models/deepseek_v2.py | 36 +--
llms/mlx_lm/models/gemma.py | 14 +-
llms/mlx_lm/models/gemma2.py | 20 +-
llms/mlx_lm/models/gpt2.py | 14 +-
llms/mlx_lm/models/gpt_bigcode.py | 14 +-
llms/mlx_lm/models/gpt_neox.py | 14 +-
llms/mlx_lm/models/internlm2.py | 14 +-
llms/mlx_lm/models/llama.py | 18 +-
llms/mlx_lm/models/mamba.py | 18 +-
llms/mlx_lm/models/minicpm.py | 14 +-
llms/mlx_lm/models/mixtral.py | 14 +-
llms/mlx_lm/models/nemotron.py | 18 +-
llms/mlx_lm/models/olmo.py | 18 +-
llms/mlx_lm/models/openelm.py | 14 +-
llms/mlx_lm/models/phi.py | 12 +-
llms/mlx_lm/models/phi3.py | 16 +-
llms/mlx_lm/models/phi3small.py | 21 +-
llms/mlx_lm/models/phimoe.py | 9 +-
llms/mlx_lm/models/phixtral.py | 12 +-
llms/mlx_lm/models/plamo.py | 26 +-
llms/mlx_lm/models/qwen.py | 13 +-
llms/mlx_lm/models/qwen2.py | 16 +-
llms/mlx_lm/models/qwen2_moe.py | 16 +-
llms/mlx_lm/models/recurrent_gemma.py | 126 ++------
llms/mlx_lm/models/stablelm.py | 13 +-
llms/mlx_lm/models/starcoder2.py | 16 +-
llms/mlx_lm/utils.py | 51 +---
llms/setup.py | 1 +
llms/tests/test_models.py | 225 ++++++++++++++-
llms/tests/test_prompt_cache.py | 220 ++++++++++++++
43 files changed, 1151 insertions(+), 691 deletions(-)
create mode 100644 llms/mlx_lm/chat.py
create mode 100644 llms/mlx_lm/examples/chat.py
create mode 100644 llms/mlx_lm/models/cache.py
create mode 100644 llms/tests/test_prompt_cache.py
diff --git a/.gitignore b/.gitignore
index f3dfe929..45445fc8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,9 @@ __pycache__/
# C extensions
*.so
+# Vim
+*.swp
+
# Distribution / packaging
.Python
build/
diff --git a/llms/README.md b/llms/README.md
index 75677865..20863041 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -20,6 +20,31 @@ The `mlx-lm` package also has:
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
+### Quick Start
+
+To generate text with an LLM use:
+
+```bash
+mlx_lm.generate --prompt "Hi!"
+```
+
+To chat with an LLM use:
+
+```bash
+mlx_lm.chat
+```
+
+This will give you a chat REPL that you can use to interact with the LLM. The
+chat context is preserved during the lifetime of the REPL.
+
+Commands in `mlx-lm` typically take command line options which let you specify
+the model, sampling parameters, and more. Use `-h` to see a list of available
+options for a command, e.g.:
+
+```bash
+mlx_lm.generate -h
+```
+
### Python API
You can use `mlx-lm` as a module:
@@ -138,7 +163,7 @@ mlx_lm.convert \
### Long Prompts and Generations
-MLX LM has some tools to scale efficiently to long prompts and generations:
+`mlx-lm` has some tools to scale efficiently to long prompts and generations:
- A rotating fixed-size key-value cache.
- Prompt caching
@@ -155,14 +180,14 @@ different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
cat prompt.txt | mlx_lm.cache_prompt \
--model mistralai/Mistral-7B-Instruct-v0.3 \
--prompt - \
- --kv-cache-file mistral_prompt.safetensors
+ --prompt-cache-file mistral_prompt.safetensors
```
Then use the cached prompt with `mlx_lm.generate`:
```
mlx_lm.generate \
- --kv-cache-file mistral_prompt.safetensors \
+ --prompt-cache-file mistral_prompt.safetensors \
--prompt "\nSummarize the above text."
```
@@ -170,9 +195,15 @@ The cached prompt is treated as a prefix to the supplied prompt. Also notice
when using a cached prompt, the model to use is read from the cache and need
not be supplied explicitly.
+Prompt caching can also be used in the Python API in order to avoid
+recomputing the prompt. This is useful in multi-turn dialogues or across
+requests that use the same context. See the
+[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
+for more usage details.
+
### Supported Models
-MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
+`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.
diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py
index 8110c823..70239db6 100644
--- a/llms/mlx_lm/_version.py
+++ b/llms/mlx_lm/_version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
-__version__ = "0.18.2"
+__version__ = "0.19.1"
diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
index 9829efb4..04e75a3e 100644
--- a/llms/mlx_lm/cache_prompt.py
+++ b/llms/mlx_lm/cache_prompt.py
@@ -7,13 +7,14 @@ import time
import mlx.core as mx
-from .utils import load, make_kv_caches
+from .models.cache import make_prompt_cache, save_prompt_cache
+from .utils import load
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(
- description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
+ description="Cache the state of a prompt to be reused with mlx_lm.generate"
)
parser.add_argument(
"--model",
@@ -60,7 +61,9 @@ def setup_arg_parser():
help="Set the maximum key-value cache size",
)
parser.add_argument(
- "--kv-cache-file", help="The file to save the KV caches in", required=True
+ "--prompt-cache-file",
+ help="The file to save the prompt cache in",
+ required=True,
)
parser.add_argument(
"--prompt",
@@ -115,7 +118,7 @@ def main():
else:
prompt = args.prompt
- cache = make_kv_caches(model, args.max_kv_size)
+ cache = make_prompt_cache(model, args.max_kv_size)
y = mx.array(tokenizer.encode(prompt))
# Process the prompt
@@ -137,16 +140,12 @@ def main():
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
print("Saving...")
- cache_dict = {}
- for i, c in enumerate(cache):
- cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
- cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
metadata = {}
metadata["model"] = args.model
metadata["chat_template"] = tokenizer.chat_template
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
- metadata["max_kv_size"] = str(args.max_kv_size)
- mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
+ print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
+ save_prompt_cache(args.prompt_cache_file, cache, metadata)
if __name__ == "__main__":
diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
new file mode 100644
index 00000000..7968a868
--- /dev/null
+++ b/llms/mlx_lm/chat.py
@@ -0,0 +1,82 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import argparse
+import json
+
+import mlx.core as mx
+
+from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
+from .utils import load, stream_generate
+
+DEFAULT_TEMP = 0.0
+DEFAULT_TOP_P = 1.0
+DEFAULT_SEED = 0
+DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
+
+
+def setup_arg_parser():
+ """Set up and return the argument parser."""
+ parser = argparse.ArgumentParser(description="Chat with an LLM")
+ parser.add_argument(
+ "--model",
+ type=str,
+ help="The path to the local model directory or Hugging Face repo.",
+ default=DEFAULT_MODEL,
+ )
+ parser.add_argument(
+ "--adapter-path",
+ type=str,
+ help="Optional path for the trained adapter weights and config.",
+ )
+ parser.add_argument(
+ "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
+ )
+ parser.add_argument(
+ "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
+ )
+ parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
+ parser.add_argument(
+ "--max-kv-size",
+ type=int,
+ help="Set the maximum key-value cache size",
+ default=None,
+ )
+ return parser
+
+
+def main():
+ parser = setup_arg_parser()
+ args = parser.parse_args()
+
+ mx.random.seed(args.seed)
+
+ model, tokenizer = load(
+ args.model,
+ adapter_path=args.adapter_path,
+ tokenizer_config={"trust_remote_code": True},
+ )
+
+ print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
+ prompt_cache = make_prompt_cache(model, args.max_kv_size)
+ while True:
+ query = input(">> ")
+ if query == "q":
+ break
+ messages = [{"role": "user", "content": query}]
+ prompt = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ for response in stream_generate(
+ model,
+ tokenizer,
+ prompt,
+ temp=args.temp,
+ top_p=args.top_p,
+ prompt_cache=prompt_cache,
+ ):
+ print(response, flush=True, end="")
+ print()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py
new file mode 100644
index 00000000..3bf01688
--- /dev/null
+++ b/llms/mlx_lm/examples/chat.py
@@ -0,0 +1,53 @@
+# Copyright © 2024 Apple Inc.
+
+"""
+An example of a multi-turn chat with prompt caching.
+"""
+
+from mlx_lm import generate, load
+from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
+
+model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
+
+# Make the initial prompt cache for the model
+prompt_cache = make_prompt_cache(model)
+
+# User turn
+prompt = "Hi my name is ."
+messages = [{"role": "user", "content": prompt}]
+prompt = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+)
+
+# Assistant response
+response = generate(
+ model,
+ tokenizer,
+ prompt=prompt,
+ verbose=True,
+ temp=0.0,
+ prompt_cache=prompt_cache,
+)
+
+# User turn
+prompt = "What's my name?"
+messages = [{"role": "user", "content": prompt}]
+prompt = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+)
+
+# Assistant response
+response = generate(
+ model,
+ tokenizer,
+ prompt=prompt,
+ verbose=True,
+ temp=0.0,
+ prompt_cache=prompt_cache,
+)
+
+# Save the prompt cache to disk to reuse it at a later time
+save_prompt_cache("mistral_prompt.safetensors", prompt_cache)
+
+# Load the prompt cache from disk
+prompt_cache = load_prompt_cache("mistral_prompt.safetensors")
diff --git a/llms/mlx_lm/examples/generate_response.py b/llms/mlx_lm/examples/generate_response.py
index af599c1b..25730617 100644
--- a/llms/mlx_lm/examples/generate_response.py
+++ b/llms/mlx_lm/examples/generate_response.py
@@ -1,3 +1,5 @@
+# Copyright © 2024 Apple Inc.
+
from mlx_lm import generate, load
# Specify the checkpoint
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 537bd853..0bf98ab2 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -6,13 +6,15 @@ import sys
import mlx.core as mx
+from .models.cache import load_prompt_cache
from .utils import generate, load
DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
-DEFAULT_TEMP = 0.6
+DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
+DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
def str2bool(string):
@@ -25,7 +27,11 @@ def setup_arg_parser():
parser.add_argument(
"--model",
type=str,
- help="The path to the local model directory or Hugging Face repo.",
+ help=(
+ "The path to the local model directory or Hugging Face repo. "
+ f"If no model is specified, then {DEFAULT_MODEL} is used."
+ ),
+ default=None,
)
parser.add_argument(
"--adapter-path",
@@ -96,7 +102,7 @@ def setup_arg_parser():
default=None,
)
parser.add_argument(
- "--kv-cache-file",
+ "--prompt-cache-file",
type=str,
default=None,
help="A file containing saved KV caches to avoid recomputing them",
@@ -131,24 +137,6 @@ def colorprint_by_t0(s, t0):
colorprint(color, s)
-def load_kv_cache_from_file(kv_cache_file):
- if kv_cache_file is None:
- return None, None
-
- kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True)
- cache_per_layer = {}
- for k, x in kv_cache.items():
- layer, kv_type = k.split("_")
- if layer not in cache_per_layer:
- cache_per_layer[layer] = {}
- cache_per_layer[layer][kv_type] = x
-
- cache_history = [None] * len(cache_per_layer)
- for layer, c in cache_per_layer.items():
- cache_history[int(layer)] = (c["keys"], c["values"])
- return cache_history, metadata
-
-
def main():
parser = setup_arg_parser()
args = parser.parse_args()
@@ -158,22 +146,33 @@ def main():
if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
- # Load the kv cache and metadata if a kv cache file is provided
- cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file)
+ # Load the prompt cache and metadata if a cache file is provided
+ using_cache = args.prompt_cache_file is not None
+ if using_cache:
+ prompt_cache, metadata = load_prompt_cache(
+ args.prompt_cache_file, return_metadata=True
+ )
# Building tokenizer_config
tokenizer_config = (
- {} if cache_history is None else json.loads(metadata["tokenizer_config"])
+ {} if not using_cache else json.loads(metadata["tokenizer_config"])
)
if args.trust_remote_code:
tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
- # If no model path is provided then use the one in the kv cache history
model_path = args.model
- if cache_history is not None and model_path is None:
- model_path = metadata["model"]
+ if using_cache:
+ if model_path is None:
+ model_path = metadata["model"]
+ elif model_path != metadata["model"]:
+ raise ValueError(
+ f"Providing a different model ({model_path}) than that "
+ f"used to create the prompt cache ({metadata['model']}) "
+ "is an error."
+ )
+ model_path = model_path or DEFAULT_MODEL
model, tokenizer = load(
model_path,
@@ -184,7 +183,7 @@ def main():
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
- elif cache_history is not None:
+ elif using_cache:
tokenizer.chat_template = metadata["chat_template"]
if not args.ignore_chat_template and (
@@ -203,7 +202,7 @@ def main():
# Treat the prompt as a suffix assuming that the prefix is in the
# stored kv cache.
- if cache_history is not None:
+ if using_cache:
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": ""}],
tokenize=False,
@@ -217,12 +216,6 @@ def main():
raise ValueError("Cannot use --colorize with --verbose=False")
formatter = colorprint_by_t0 if args.colorize else None
- # Determine the max kv size from the kv cache or passed arguments
- max_kv_size = args.max_kv_size
- if cache_history is not None:
- max_kv_size = metadata["max_kv_size"]
- max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
-
response = generate(
model,
tokenizer,
@@ -232,8 +225,8 @@ def main():
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
- max_kv_size=max_kv_size,
- cache_history=cache_history,
+ max_kv_size=args.max_kv_size,
+ prompt_cache=prompt_cache if using_cache else None,
)
if not args.verbose:
print(response)
diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index dc19dd05..3628a808 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -2,145 +2,9 @@
import inspect
from dataclasses import dataclass
-from typing import Any, List, Optional
+from typing import Any, Optional
import mlx.core as mx
-import mlx.nn as nn
-
-
-class KVCache:
-
- def __init__(self, head_dim, n_kv_heads):
- self.n_kv_heads = n_kv_heads
- if isinstance(head_dim, int):
- self.k_head_dim = self.v_head_dim = head_dim
- elif isinstance(head_dim, tuple) and len(head_dim) == 2:
- self.k_head_dim, self.v_head_dim = head_dim
- else:
- raise ValueError("head_dim must be an int or a tuple of two ints")
- self.keys = None
- self.values = None
- self.offset = 0
- self.step = 256
-
- def update_and_fetch(self, keys, values):
- prev = self.offset
- if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
- B = keys.shape[0]
- n_steps = (self.step + keys.shape[2] - 1) // self.step
- k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
- v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
- new_k = mx.zeros(k_shape, keys.dtype)
- new_v = mx.zeros(v_shape, values.dtype)
- if self.keys is not None:
- if prev % self.step != 0:
- self.keys = self.keys[..., :prev, :]
- self.values = self.values[..., :prev, :]
- self.keys = mx.concatenate([self.keys, new_k], axis=2)
- self.values = mx.concatenate([self.values, new_v], axis=2)
- else:
- self.keys, self.values = new_k, new_v
-
- self.offset += keys.shape[2]
- self.keys[..., prev : self.offset, :] = keys
- self.values[..., prev : self.offset, :] = values
- return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
-
- @property
- def state(self):
- return self.keys, self.values
-
-
-class RotatingKVCache:
-
- def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
- self.n_kv_heads = n_kv_heads
- if isinstance(head_dim, int):
- self.k_head_dim = self.v_head_dim = head_dim
- elif isinstance(head_dim, tuple) and len(head_dim) == 2:
- self.k_head_dim, self.v_head_dim = head_dim
- else:
- raise ValueError("head_dim must be an int or a tuple of two ints")
- self.keep = keep
- self.keys = None
- self.values = None
- self.offset = 0
- self.max_size = max_size
- self.step = step
- self._idx = 0
-
- def _trim(self, trim_size, v, append=None):
- to_cat = []
- if trim_size > 0:
- to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
- else:
- to_cat = [v]
- if append is not None:
- to_cat.append(append)
- return mx.concatenate(to_cat, axis=2)
-
- def update_and_fetch(self, keys, values):
- prev = self.offset
- B, _, S = keys.shape[:3]
-
- # Prefill mode
- if S > 1:
- if self.keys is None:
- self.keys = keys
- self.values = values
- else:
- # The largest size is self.max_size + S - 1 to ensure
- # every token gets at least self.max_size context
- trim_size = self.keys.shape[2] - self.max_size + 1
- self.keys = self._trim(trim_size, self.keys, keys)
- self.values = self._trim(trim_size, self.values, values)
- self.offset += S
- self._idx = self.keys.shape[2]
- return self.keys, self.values
-
- # Generation mode
- # May not have hit the max size yet, so potentially
- # keep growing the cache
- if self.keys is None or (
- prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
- ):
- new_size = min(self.step, self.max_size - prev)
- k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
- v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
- new_k = mx.zeros(k_shape, keys.dtype)
- new_v = mx.zeros(v_shape, values.dtype)
- if self.keys is not None:
- self.keys = mx.concatenate([self.keys, new_k], axis=2)
- self.values = mx.concatenate([self.values, new_v], axis=2)
- else:
- self.keys, self.values = new_k, new_v
- self._idx = prev
-
- # Trim if needed
- trim_size = self.keys.shape[2] - self.max_size
- if trim_size > 0:
- self.keys = self._trim(trim_size, self.keys)
- self.values = self._trim(trim_size, self.values)
- self._idx = self.max_size
-
- # Rotate
- if self._idx == self.max_size:
- self._idx = self.keep
-
- # Assign
- self.keys[..., self._idx : self._idx + 1, :] = keys
- self.values[..., self._idx : self._idx + 1, :] = values
- self.offset += 1
- self._idx += 1
-
- # If the buffer is not full, slice off the end
- if self.offset < self.max_size:
- return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
- return self.keys, self.values
-
- @property
- def state(self):
- return self.keys, self.values
@dataclass
@@ -156,25 +20,30 @@ class BaseModelArgs:
)
-def create_additive_causal_mask(N: int, offset: int = 0):
+def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
- mask = linds[:, None] < rinds[None]
+ linds = linds[:, None]
+ rinds = rinds[None]
+ mask = linds < rinds
+ if window_size is not None:
+ mask = mask | (linds > rinds + window_size)
return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
T = h.shape[1]
if T > 1:
+ window_size = None
+ offset = 0
if cache is not None and cache[0] is not None:
c = cache[0]
- if isinstance(c, RotatingKVCache):
+ if hasattr(c, "max_size"):
offset = min(c.max_size - 1, c.offset)
+ window_size = c.max_size
else:
offset = c.offset
- else:
- offset = 0
- mask = create_additive_causal_mask(T, offset)
+ mask = create_causal_mask(T, offset, window_size=window_size)
mask = mask.astype(h.dtype)
else:
mask = None
diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
new file mode 100644
index 00000000..b06422e5
--- /dev/null
+++ b/llms/mlx_lm/models/cache.py
@@ -0,0 +1,333 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from typing import Any, Dict, List, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_flatten, tree_unflatten
+
+
+def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
+ """
+ Construct the model's cache for use when generating.
+
+ This function will defer the cache construction to the model if it has a
+ ``make_cache`` method, otherwise it will make a default KV cache.
+
+ Args:
+ model (nn.Module): The language model.
+ max_kv_size (Optional[int]): If provided and the model does not have a
+ ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
+ size of ``max_kv_size``
+ """
+ if hasattr(model, "make_cache"):
+ return model.make_cache()
+
+ num_layers = len(model.layers)
+ if max_kv_size is not None:
+ return [
+ RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
+ ]
+ else:
+ return [KVCache() for _ in range(num_layers)]
+
+
+def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
+ """
+ Save a pre-computed prompt cache to a file.
+
+ Args:
+ file_name (str): The ``.safetensors`` file name.
+ cache (List[Any]): The model state.
+ metadata (Dict[str, str]): Optional metadata to save along with model
+ state.
+ """
+ cache_data = [c.state for c in cache]
+ cache_info = [c.meta_state for c in cache]
+ cache_data = dict(tree_flatten(cache_data))
+ cache_classes = [type(c).__name__ for c in cache]
+ cache_metadata = [cache_info, metadata, cache_classes]
+ cache_metadata = dict(tree_flatten(cache_metadata))
+ mx.save_safetensors(file_name, cache_data, cache_metadata)
+
+
+def load_prompt_cache(file_name, return_metadata=False):
+ """
+ Load a prompt cache from a file.
+
+ Args:
+ file_name (str): The ``.safetensors`` file name.
+ return_metadata (bool): Whether or not to return metadata.
+ Default: ``False``.
+
+ Returns:
+ List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
+ the metadata if requested.
+ """
+ arrays, cache_metadata = mx.load(file_name, return_metadata=True)
+ arrays = tree_unflatten(list(arrays.items()))
+ cache_metadata = tree_unflatten(list(cache_metadata.items()))
+ info, metadata, classes = cache_metadata
+ cache = [globals()[c]() for c in classes]
+ for c, state, meta_state in zip(cache, arrays, info):
+ c.state = state
+ c.meta_state = meta_state
+ if return_metadata:
+ return cache, metadata
+ return cache
+
+
+def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
+ """
+ Trim the model's cache by the given number of tokens.
+
+ This function will trim the cache if possible (in-place) and return the
+ number of tokens that were trimmed.
+
+ Args:
+ cache (List[Any]): The model's cache.
+ num_tokens (int): The number of tokens to trim.
+
+ Returns:
+ (int): The number of tokens that were trimmed.
+ """
+ if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
+ return 0
+ return [c.trim(num_tokens) for c in cache][0]
+
+
+class _BaseCache:
+ @property
+ def state(self):
+ return []
+
+ @state.setter
+ def state(self, v):
+ if v is not None and v:
+ raise ValueError("This cache has no state but a state was set.")
+
+ @property
+ def meta_state(self):
+ return ""
+
+ @meta_state.setter
+ def meta_state(self, v):
+ if v is not None and v:
+ raise ValueError("This cache has no meta_state but a meta_state was set.")
+
+ def is_trimmable(self):
+ return False
+
+
+class KVCache(_BaseCache):
+ def __init__(self):
+ self.keys = None
+ self.values = None
+ self.offset = 0
+ self.step = 256
+
+ def update_and_fetch(self, keys, values):
+ prev = self.offset
+ if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
+ B, n_kv_heads, _, k_head_dim = keys.shape
+ v_head_dim = values.shape[3]
+ n_steps = (self.step + keys.shape[2] - 1) // self.step
+ k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
+ v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
+ new_k = mx.zeros(k_shape, keys.dtype)
+ new_v = mx.zeros(v_shape, values.dtype)
+ if self.keys is not None:
+ if prev % self.step != 0:
+ self.keys = self.keys[..., :prev, :]
+ self.values = self.values[..., :prev, :]
+ self.keys = mx.concatenate([self.keys, new_k], axis=2)
+ self.values = mx.concatenate([self.values, new_v], axis=2)
+ else:
+ self.keys, self.values = new_k, new_v
+
+ self.offset += keys.shape[2]
+ self.keys[..., prev : self.offset, :] = keys
+ self.values[..., prev : self.offset, :] = values
+ return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
+
+ @property
+ def state(self):
+ if self.offset == self.keys.shape[2]:
+ return self.keys, self.values
+ else:
+ return (
+ self.keys[..., : self.offset, :],
+ self.values[..., : self.offset, :],
+ )
+
+ @state.setter
+ def state(self, v):
+ self.keys, self.values = v
+ self.offset = self.keys.shape[2]
+
+ def is_trimmable(self):
+ return True
+
+ def trim(self, n):
+ n = min(self.offset, n)
+ self.offset -= n
+ return n
+
+
+class RotatingKVCache(_BaseCache):
+
+ def __init__(self, max_size=None, keep=0, step=256):
+ self.keep = keep
+ self.keys = None
+ self.values = None
+ self.offset = 0
+ self.max_size = max_size
+ self.step = step
+ self._idx = 0
+
+ def _trim(self, trim_size, v, append=None):
+ to_cat = []
+ if trim_size > 0:
+ to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
+ else:
+ to_cat = [v]
+ if append is not None:
+ to_cat.append(append)
+ return mx.concatenate(to_cat, axis=2)
+
+ def _temporal_order(self, v):
+ """
+ Rearrange the cache into temporal order, slicing off the end if unused.
+ """
+ if self._idx == v.shape[2]:
+ return v
+ elif self._idx < self.offset:
+ return mx.concatenate(
+ [
+ v[..., : self.keep, :],
+ v[..., self._idx :, :],
+ v[..., self.keep : self._idx, :],
+ ],
+ axis=2,
+ )
+ else:
+ return v[..., : self._idx, :]
+
+ def _update_concat(self, keys, values):
+ if self.keys is None:
+ self.keys = keys
+ self.values = values
+ else:
+ # Put the keys/values in temporal order to
+ # preserve context
+ self.keys = self._temporal_order(self.keys)
+ self.values = self._temporal_order(self.values)
+
+ # The largest size is self.max_size + S - 1 to ensure
+ # every token gets at least self.max_size context
+ trim_size = self._idx - self.max_size + 1
+ self.keys = self._trim(trim_size, self.keys, keys)
+ self.values = self._trim(trim_size, self.values, values)
+ self.offset += keys.shape[2]
+ self._idx = self.keys.shape[2]
+ return self.keys, self.values
+
+ def _update_in_place(self, keys, values):
+ # May not have hit the max size yet, so potentially
+ # keep growing the cache
+ B, n_kv_heads, S, k_head_dim = keys.shape
+ prev = self.offset
+ if self.keys is None or (
+ prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
+ ):
+ v_head_dim = values.shape[3]
+ new_size = min(self.step, self.max_size - prev)
+ k_shape = (B, n_kv_heads, new_size, k_head_dim)
+ v_shape = (B, n_kv_heads, new_size, v_head_dim)
+ new_k = mx.zeros(k_shape, keys.dtype)
+ new_v = mx.zeros(v_shape, values.dtype)
+ if self.keys is not None:
+ self.keys = mx.concatenate([self.keys, new_k], axis=2)
+ self.values = mx.concatenate([self.values, new_v], axis=2)
+ else:
+ self.keys, self.values = new_k, new_v
+ self._idx = prev
+
+ # Trim if needed
+ trim_size = self.keys.shape[2] - self.max_size
+ if trim_size > 0:
+ self.keys = self._trim(trim_size, self.keys)
+ self.values = self._trim(trim_size, self.values)
+ self._idx = self.max_size
+
+ # Rotate
+ if self._idx == self.max_size:
+ self._idx = self.keep
+
+ # Assign
+ self.keys[..., self._idx : self._idx + S, :] = keys
+ self.values[..., self._idx : self._idx + S, :] = values
+ self.offset += S
+ self._idx += S
+
+ # If the buffer is not full, slice off the end
+ if self.offset < self.max_size:
+ return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
+ return self.keys, self.values
+
+ def update_and_fetch(self, keys, values):
+ if keys.shape[2] == 1:
+ return self._update_in_place(keys, values)
+ return self._update_concat(keys, values)
+
+ @property
+ def state(self):
+ if self.offset < self.keys.shape[2]:
+ return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
+ else:
+ return self.keys, self.values
+
+ @state.setter
+ def state(self, v):
+ self.keys, self.values = v
+
+ @property
+ def meta_state(self):
+ return tuple(
+ map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
+ )
+
+ @meta_state.setter
+ def meta_state(self, v):
+ self.keep, self.max_size, self.step, self.offset, self._idx = map(
+ int,
+ v,
+ )
+
+ def is_trimmable(self):
+ return self.offset < self.max_size
+
+ def trim(self, n):
+ n = min(self.offset, n)
+ self.offset -= n
+ self._idx -= n
+ return n
+
+
+class MambaCache(_BaseCache):
+ def __init__(self):
+ self.cache = [None, None]
+
+ def __setitem__(self, idx, value):
+ self.cache[idx] = value
+
+ def __getitem__(self, idx):
+ return self.cache[idx]
+
+ @property
+ def state(self):
+ return self.cache
+
+ @state.setter
+ def state(self, v):
+ self.cache = v
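To tie the new `cache.py` helpers together, here is a minimal usage sketch of the prompt-cache API defined above (`make_prompt_cache`, `save_prompt_cache`, `load_prompt_cache`, `trim_prompt_cache`). The model repo id, file name, and metadata values are placeholders for illustration.

```python
# Sketch only: exercises the cache API added in this patch. The repo id and
# file name are placeholders.
from mlx_lm import generate, load
from mlx_lm.models.cache import (
    load_prompt_cache,
    make_prompt_cache,
    save_prompt_cache,
    trim_prompt_cache,
)

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Per-layer KVCache by default; RotatingKVCache when max_kv_size is given.
cache = make_prompt_cache(model, max_kv_size=None)

# Fill the cache by running a generation with it.
generate(model, tokenizer, prompt="Hello!", verbose=False, prompt_cache=cache)

# Persist the filled cache together with optional string metadata.
save_prompt_cache("prompt.safetensors", cache, {"note": "example"})

# Reload it later, optionally recovering the metadata.
cache, metadata = load_prompt_cache("prompt.safetensors", return_metadata=True)

# Drop up to 8 tokens from the end of every layer's cache; returns the count trimmed.
num_trimmed = trim_prompt_cache(cache, 8)
```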
diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py
index cfcf2945..057c816d 100644
--- a/llms/mlx_lm/models/cohere.py
+++ b/llms/mlx_lm/models/cohere.py
@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -69,7 +69,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -129,7 +129,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache)
@@ -190,11 +190,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py
index f0214549..3b7e83d7 100644
--- a/llms/mlx_lm/models/dbrx.py
+++ b/llms/mlx_lm/models/dbrx.py
@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -49,7 +49,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
qkv = self.Wqkv(x)
@@ -92,7 +92,7 @@ class NormAttnNorm(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
h = self.attn(self.norm_1(x), mask=mask, cache=cache)
x = h + x
@@ -179,7 +179,7 @@ class DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r, h = self.norm_attn_norm(x, mask, cache)
out = self.ffn(h) + r
@@ -249,11 +249,3 @@ class Model(nn.Module):
experts = [(s, sv.T) for s, sv in experts]
new_weights.update(experts)
return new_weights
-
- @property
- def head_dim(self):
- return self.args.d_model // self.args.n_heads
-
- @property
- def n_kv_heads(self):
- return self.args.attn_config["kv_n_heads"]
diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py
index dcfa331c..03cb3b1a 100644
--- a/llms/mlx_lm/models/deepseek.py
+++ b/llms/mlx_lm/models/deepseek.py
@@ -1,10 +1,10 @@
from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@@ -77,7 +77,7 @@ class DeepseekAttention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
@@ -108,8 +108,8 @@ class DeepseekMLP(nn.Module):
def __init__(
self,
config: ModelArgs,
- hidden_size: int | None = None,
- intermediate_size: int | None = None,
+ hidden_size: Optional[int] = None,
+ intermediate_size: Optional[int] = None,
):
super().__init__()
self.config = config
@@ -188,7 +188,7 @@ class DeepseekDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -210,7 +210,7 @@ class DeepseekModel(nn.Module):
def __call__(
self,
x: mx.array,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
@@ -235,7 +235,7 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
@@ -256,11 +256,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py
index 602a9710..17d061a8 100644
--- a/llms/mlx_lm/models/deepseek_v2.py
+++ b/llms/mlx_lm/models/deepseek_v2.py
@@ -2,12 +2,12 @@
import math
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@@ -38,7 +38,7 @@ class ModelArgs(BaseModelArgs):
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
- rope_scaling: Optional[Dict] = None
+ rope_scaling: Dict = None
attention_bias: bool = False
@@ -172,12 +172,11 @@ class DeepseekV2Attention(nn.Module):
bias=config.attention_bias,
)
- if self.config.rope_scaling is not None:
- mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
- scaling_factor = self.config.rope_scaling["factor"]
- if mscale_all_dim:
- mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
- self.scale = self.scale * mscale * mscale
+ mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+ scaling_factor = self.config.rope_scaling["factor"]
+ if mscale_all_dim:
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+ self.scale = self.scale * mscale * mscale
rope_kwargs = {
key: self.config.rope_scaling[key]
@@ -202,7 +201,7 @@ class DeepseekV2Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -347,7 +346,7 @@ class DeepseekV2DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -370,7 +369,7 @@ class DeepseekV2Model(nn.Module):
def __call__(
self,
x: mx.array,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
@@ -395,7 +394,7 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
@@ -416,14 +415,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return (
- self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
- self.args.v_head_dim,
- )
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py
index c6150284..61de781e 100644
--- a/llms/mlx_lm/models/gemma.py
+++ b/llms/mlx_lm/models/gemma.py
@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -60,7 +60,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -113,7 +113,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -173,11 +173,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return self.args.head_dim
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py
index 1d410a15..ccc327a8 100644
--- a/llms/mlx_lm/models/gemma2.py
+++ b/llms/mlx_lm/models/gemma2.py
@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -64,7 +64,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
@@ -135,13 +135,11 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
- r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache)
+ r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + self.post_attention_layernorm(r)
- r = self.mlp(self.pre_feedforward_layernorm(h).astype(mx.float16)).astype(
- mx.float32
- )
+ r = self.mlp(self.pre_feedforward_layernorm(h))
out = h + self.post_feedforward_layernorm(r)
return out
@@ -200,11 +198,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return self.args.head_dim
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py
index 8a770936..97d9a8ff 100644
--- a/llms/mlx_lm/models/gpt2.py
+++ b/llms/mlx_lm/models/gpt2.py
@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -46,7 +46,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -100,7 +100,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
@@ -196,11 +196,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.h
-
- @property
- def head_dim(self):
- return self.args.n_embd // self.args.n_head
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py
index 652eb9e4..068046ea 100644
--- a/llms/mlx_lm/models/gpt_bigcode.py
+++ b/llms/mlx_lm/models/gpt_bigcode.py
@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -57,7 +57,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -114,7 +114,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
@@ -184,11 +184,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.h
-
- @property
- def head_dim(self):
- return self.args.n_embd // self.args.n_head
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py
index c2aaa9ea..9f662491 100644
--- a/llms/mlx_lm/models/gpt_neox.py
+++ b/llms/mlx_lm/models/gpt_neox.py
@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -60,7 +60,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -120,7 +120,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
residual = x
# NeoX runs attention and feedforward network in parallel.
@@ -214,11 +214,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.h
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py
index bcc0cf0c..5264cb57 100644
--- a/llms/mlx_lm/models/internlm2.py
+++ b/llms/mlx_lm/models/internlm2.py
@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -116,7 +116,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -171,7 +171,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.attention(self.attention_norm(x), mask, cache)
h = x + r
@@ -236,11 +236,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index c4a947a5..7da6b333 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -171,7 +171,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -233,7 +233,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -303,13 +303,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return (
- self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
- )
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index 26408426..d2740dc1 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -7,6 +7,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
+from .cache import MambaCache
@dataclass
@@ -45,21 +46,6 @@ class ModelArgs(BaseModelArgs):
self.time_step_rank = math.ceil(self.hidden_size / 16)
-class MambaCache:
- def __init__(self):
- self.cache = [None, None]
-
- def __setitem__(self, idx, value):
- self.cache[idx] = value
-
- def __getitem__(self, idx):
- return self.cache[idx]
-
- @property
- def state(self):
- return self.cache
-
-
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, padding=0):
super().__init__()
@@ -223,7 +209,7 @@ class Model(nn.Module):
weights[k] = v.moveaxis(2, 1)
return weights
- def make_cache(self, batch_size: int = 1):
+ def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py
index df0670be..4ac3c3b4 100644
--- a/llms/mlx_lm/models/minicpm.py
+++ b/llms/mlx_lm/models/minicpm.py
@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -85,7 +85,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
):
B, L, _ = x.shape
@@ -135,7 +135,7 @@ class DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
@@ -205,11 +205,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index 2db57752..20944fe3 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -2,7 +2,7 @@
import math
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -66,7 +66,7 @@ class MixtralAttention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -138,7 +138,7 @@ class MixtralDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -215,11 +215,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py
index ef55d1d7..3ea06e27 100644
--- a/llms/mlx_lm/models/nemotron.py
+++ b/llms/mlx_lm/models/nemotron.py
@@ -2,12 +2,12 @@
from dataclasses import dataclass
from functools import partial
-from typing import Dict, Optional, Union
+from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -94,7 +94,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
@@ -151,7 +151,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -215,13 +215,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return (
- self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
- )
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index 59849c96..3627df06 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -1,8 +1,8 @@
# Copyright © 2023-2024 Apple Inc.
+import sys
from dataclasses import dataclass
-from sys import exit
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -13,7 +13,7 @@ try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
- exit(1)
+ sys.exit(1)
@dataclass
@@ -68,7 +68,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.attend(self.att_norm(x), mask, cache)
h = x + r
@@ -174,11 +174,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.transformer.blocks
-
- @property
- def head_dim(self):
- return self.args.d_model // self.args.n_heads
-
- @property
- def n_kv_heads(self):
- return self.args.n_heads
diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py
index 19d3c027..090e21c6 100644
--- a/llms/mlx_lm/models/openelm.py
+++ b/llms/mlx_lm/models/openelm.py
@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -80,7 +80,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -152,7 +152,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.attn_norm(x), mask, cache)
h = x + r
@@ -218,11 +218,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.layers
-
- @property
- def head_dim(self):
- return self.args.head_dim
-
- @property
- def n_kv_heads(self):
- return self.args.num_kv_heads
diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py
index fd3fd709..56b383b2 100644
--- a/llms/mlx_lm/models/phi.py
+++ b/llms/mlx_lm/models/phi.py
@@ -162,19 +162,11 @@ class Model(nn.Module):
def __call__(
self,
x: mx.array,
- cache: mx.array = None,
- ) -> Tuple[mx.array, mx.array]:
+ cache=None,
+ ) -> mx.array:
y = self.model(x, cache)
return self.lm_head(y)
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index 112ade7d..9ef76f04 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask
from .su_rope import SuScaledRotaryEmbedding
@@ -84,7 +84,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -143,7 +143,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -202,11 +202,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py
index 665dbc73..6b0759b4 100644
--- a/llms/mlx_lm/models/phi3small.py
+++ b/llms/mlx_lm/models/phi3small.py
@@ -3,12 +3,12 @@
import math
from dataclasses import dataclass
from functools import partial
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -22,14 +22,14 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
layer_norm_epsilon: float
vocab_size: int
- num_key_value_heads: Optional[int] = None
+ num_key_value_heads: int
mup_attn_multiplier: float = 1.0
mup_use_scaling: bool = True
mup_embedding_multiplier: float = 10.0
mup_width_multiplier: float = 8.0
rope_embedding_base: float = 1000000
rope_position_scale: float = 1.0
- blocksparse_block_size: Tuple[int] = (64,)
+ blocksparse_block_size: int = 64
blocksparse_num_local_blocks: int = 16
blocksparse_vert_stride: int = 8
@@ -61,7 +61,6 @@ class Attention(nn.Module):
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
- assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.n_q_per_kv = n_heads // n_kv_heads
@@ -161,7 +160,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -230,7 +229,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -304,16 +303,8 @@ class Model(nn.Module):
def layers(self):
return self.model.layers
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py
index db6bd4b5..ca20a388 100644
--- a/llms/mlx_lm/models/phimoe.py
+++ b/llms/mlx_lm/models/phimoe.py
@@ -173,6 +173,7 @@ class PhiMoEModel(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
+ self.model_type = args.model_type
self.args = args
self.model = PhiMoEModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)
@@ -208,11 +209,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py
index bb67615d..865d0d8e 100644
--- a/llms/mlx_lm/models/phixtral.py
+++ b/llms/mlx_lm/models/phixtral.py
@@ -168,8 +168,8 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
- cache: mx.array = None,
- ) -> Tuple[mx.array, mx.array]:
+ cache=None,
+ ) -> mx.array:
mask = create_attention_mask(x, cache)
y = self.transformer(x, mask, cache)
@@ -193,11 +193,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.h
-
- @property
- def head_dim(self):
- return self.args.model_dim // self.args.num_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_heads
diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py
index 5d2b7586..090922ae 100644
--- a/llms/mlx_lm/models/plamo.py
+++ b/llms/mlx_lm/models/plamo.py
@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
@@ -62,8 +62,8 @@ class Attention(nn.Module):
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
- ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
+ cache: Optional[Any] = None,
+ ) -> mx.array:
bsz, q_len, _ = hidden_states.shape
queries = self.q_proj(hidden_states)
@@ -127,8 +127,8 @@ class PlamoDecoderLayer(nn.Module):
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
- cache: Optional[Tuple[mx.array, mx.array]] = None,
- ) -> Tuple[Any, ...]:
+ cache: Optional[Any] = None,
+ ):
# from LlamaDecoder
residual = hidden_states
@@ -169,8 +169,8 @@ class PlamoModel(nn.Module):
def __call__(
self,
inputs: mx.array,
- cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None,
- ) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
+ cache: Optional[Any] = None,
+ ) -> mx.array:
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
@@ -197,19 +197,11 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
- cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
- ) -> Tuple[mx.array, mx.array]:
+ cache: Optional[Any] = None,
+ ) -> mx.array:
out = self.model(inputs, cache)
return self.lm_head(out)
@property
def layers(self):
return self.model.layers.layers
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_attention_heads // self.args.n_shared_head
diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py
index 6d2c7bbf..2b69d5ec 100644
--- a/llms/mlx_lm/models/qwen.py
+++ b/llms/mlx_lm/models/qwen.py
@@ -1,7 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -149,19 +148,11 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
- cache: mx.array = None,
- ) -> Tuple[mx.array, mx.array]:
+ cache=None,
+ ) -> mx.array:
y = self.transformer(x, mask, cache)
return self.lm_head(y)
@property
def layers(self):
return self.transformer.h
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_attention_heads
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index b3ce02a3..4e7858de 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -70,7 +70,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -124,7 +124,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -196,11 +196,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py
index ff7831f3..d199116f 100644
--- a/llms/mlx_lm/models/qwen2_moe.py
+++ b/llms/mlx_lm/models/qwen2_moe.py
@@ -2,12 +2,12 @@
import math
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@@ -70,7 +70,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -162,7 +162,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -236,11 +236,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py
index 34750ace..06a307a6 100644
--- a/llms/mlx_lm/models/recurrent_gemma.py
+++ b/llms/mlx_lm/models/recurrent_gemma.py
@@ -7,13 +7,13 @@ from typing import List, Literal, Optional
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
+from .cache import MambaCache, RotatingKVCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
- hidden_size: int
attention_bias: bool
conv1d_width: int
hidden_size: int
@@ -36,59 +36,6 @@ class ModelArgs(BaseModelArgs):
self.block_types = self._block_types
-def create_window_causal_mask(N: int, window_size: int):
- inds = mx.arange(N)
- linds = inds[:, None]
- rinds = inds[None]
- mask = (linds < rinds) | (linds > rinds + window_size)
- return mask * -1e9
-
-
-class RecurrentCache:
-
- def __init__(self):
- self._cache = (None, None)
-
- def __getitem__(self, idx):
- return self._cache[idx]
-
- def update(self, conv_state, recurrent_state):
- self._cache = (conv_state, recurrent_state)
-
- def state(self):
- return self._cache
-
-
-class WindowKVCache:
-
- def __init__(self, window_size):
- self.keys = None
- self.values = None
- self.offset = 0
- self.window_size = window_size
-
- def update_and_fetch(self, keys, values):
- # TODO consider using rotating buffer here
- # especially for very long generations
- def _update(x, v):
- t = x.shape[2] - self.window_size
- if t > 0:
- x = x[..., t:, :]
- return mx.concatenate([x, v], axis=2)
-
- self.offset += keys.shape[2]
- if self.keys is None:
- self.keys = keys
- self.values = values
- else:
- self.keys = _update(self.keys, keys)
- self.values = _update(self.values, values)
- return self.keys, self.values
-
- def state(self):
- return self.keys, self.values
-
-
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
@@ -136,31 +83,22 @@ class Conv1d(nn.Module):
kernel_size: int,
):
super().__init__()
- self.weight = mx.zeros((kernel_size, channels))
+ self.weight = mx.zeros((channels, kernel_size, 1))
self.bias = mx.zeros((channels,))
def __call__(self, x, cache=None):
- w = self.weight.T[..., None]
- kw, groups = self.weight.shape
- if cache is not None:
- l = []
- # Pad the cache if needed
- if cache.shape[1] < kw - 1:
- l.append(
- mx.zeros(
- (x.shape[0], kw - 1 - cache.shape[1], groups), dtype=x.dtype
- )
- )
- l.extend([cache, x])
- x = mx.concatenate(l, axis=1)
- y = (x * w.swapaxes(0, 2)).sum(axis=1, keepdims=True)
- else:
- y = mx.conv_general(x, w, padding=([kw - 1], [0]), groups=groups)
+ B, L, C = x.shape
+ groups, K, _ = self.weight.shape
- # The cache is always kw - 1
- cache = x[:, max(x.shape[1] - kw + 1, 0) :, :]
+ if cache is not None:
+ x = mx.concatenate([cache, x], axis=1)
+ else:
+ x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
+
+ y = mx.conv_general(x, self.weight, groups=groups)
y = y + self.bias
- return y, cache
+
+ return y, x[:, -K + 1 :, :]
class RGLRU(nn.Module):
@@ -269,19 +207,9 @@ class RecurrentBlock(nn.Module):
# x branch.
x = self.linear_x(x)
if cache is None:
- conv_state, recurrent_state = (None, None)
- else:
- conv_state, recurrent_state = cache[0], cache[1]
- x, conv_state = self.conv_1d(
- x=x,
- cache=conv_state,
- )
- x, recurrent_state = self.rg_lru(
- x=x,
- cache=recurrent_state,
- )
- if cache is not None:
- cache.update(conv_state, recurrent_state)
+ cache = [None, None]
+ x, cache[0] = self.conv_1d(x=x, cache=cache[0])
+ x, cache[1] = self.rg_lru(x=x, cache=cache[1])
x = x * y
x = self.linear_out(x)
@@ -467,12 +395,14 @@ class Griffin(nn.Module):
if self.scale_by_sqrt_dim:
x = x * math.sqrt(x.shape[-1])
- mask = None
- if x.shape[1] > 1:
- mask = create_window_causal_mask(
- x.shape[1], self.config.attention_window_size
- )
- mask = mask.astype(x.dtype)
+ if cache is None:
+ cache = [None] * len(self.layers)
+
+ for i, block in enumerate(self.layers):
+ if block.temporal_block_type != "recurrent":
+ mask_cache = [cache[i]]
+
+ mask = create_attention_mask(x, mask_cache)
for i, block in enumerate(self.layers):
x = block(x, mask=mask, cache=cache[i])
@@ -485,6 +415,7 @@ class Model(nn.Module):
def __init__(self, config):
self.args = config
self.model = Griffin(config)
+ self.model_type = config.model_type
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
@@ -508,10 +439,9 @@ class Model(nn.Module):
return self.model.layers
def sanitize(self, weights):
- # Remove unused precomputed rotary freqs
for k, v in weights.items():
if "conv_1d.weight" in k and v.ndim == 3:
- weights[k] = v.squeeze(1).T
+ weights[k] = v.moveaxis(2, 1)
if "lm_head.weight" not in weights:
self.pop("lm_head")
return weights
@@ -520,7 +450,7 @@ class Model(nn.Module):
cache = []
for layer in self.layers:
if layer.temporal_block_type == "recurrent":
- cache.append(RecurrentCache())
+ cache.append(MambaCache())
else:
- cache.append(WindowKVCache(self.args.attention_window_size))
+ cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
return cache
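
With the custom `RecurrentCache`/`WindowKVCache` classes gone, `make_cache` builds its per-layer list from the shared cache module. A hedged sketch of the same pattern in isolation (the block layout and window size are illustrative values, not read from a real config):

```python
from mlx_lm.models.cache import MambaCache, RotatingKVCache

# Illustrative layout; the model reads block_types and the window size
# from its configuration.
block_types = ["recurrent", "recurrent", "attention"]
attention_window_size = 1024

cache = [
    MambaCache()
    if block_type == "recurrent"
    else RotatingKVCache(max_size=attention_window_size)
    for block_type in block_types
]
```
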
diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py
index b340de28..11202b02 100644
--- a/llms/mlx_lm/models/stablelm.py
+++ b/llms/mlx_lm/models/stablelm.py
@@ -2,7 +2,6 @@
import math
from dataclasses import dataclass
-from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -198,8 +197,8 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
- cache: mx.array = None,
- ) -> Tuple[mx.array, mx.array]:
+ cache=None,
+ ) -> mx.array:
mask = create_attention_mask(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)
@@ -207,11 +206,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index 9cec0e39..ce0a2ec5 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
-from .base import BaseModelArgs, KVCache, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -45,7 +45,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -100,7 +100,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
- cache: Optional[KVCache] = None,
+ cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -164,11 +164,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
-
- @property
- def head_dim(self):
- return self.args.hidden_size // self.args.num_attention_heads
-
- @property
- def n_kv_heads(self):
- return self.args.num_key_value_heads
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 54a96457..8649fbe3 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -18,7 +18,7 @@ from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer
# Local imports
-from .models.base import KVCache, RotatingKVCache
+from .models import base, cache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
@@ -124,26 +124,6 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
return logits
-def make_kv_caches(
- model: nn.Module, max_kv_size: Optional[int] = None
-) -> List[Union[KVCache, RotatingKVCache]]:
- if hasattr(model, "make_cache"):
- return model.make_cache()
-
- kv_heads = (
- [model.n_kv_heads] * len(model.layers)
- if isinstance(model.n_kv_heads, int)
- else model.n_kv_heads
- )
- if max_kv_size is not None:
- return [
- RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
- for n in kv_heads
- ]
- else:
- return [KVCache(model.head_dim, n) for n in kv_heads]
-
-
def generate_step(
prompt: mx.array,
model: nn.Module,
@@ -155,7 +135,7 @@ def generate_step(
min_tokens_to_keep: int = 1,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None,
- cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
+ prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
@@ -180,6 +160,8 @@ def generate_step(
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
+ prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+ provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias.
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
@@ -237,20 +219,13 @@ def generate_step(
tokens = None
# Create the KV cache for generation
- cache = make_kv_caches(model, max_kv_size)
-
- if cache_history is not None:
- if len(cache_history) != len(cache):
- raise ValueError("Wrong number of layers in the cache history")
-
- # Set the history in the cache objects and evaluate them to prepare for
- # generation.
- for c, h in zip(cache, cache_history):
- c.update_and_fetch(h[0], h[1])
- mx.eval([c.state for c in cache])
+ if prompt_cache is None:
+ prompt_cache = cache.make_prompt_cache(model, max_kv_size)
+ elif len(prompt_cache) != len(model.layers):
+ raise ValueError("Wrong number of layers in the prompt cache.")
def _step(y):
- logits = model(y[None], cache=cache)
+ logits = model(y[None], cache=prompt_cache)
logits = logits[:, -1, :]
if logits_processor:
@@ -305,9 +280,9 @@ def stream_generate(
detokenizer = tokenizer.detokenizer
detokenizer.reset()
- for (token, _), n in zip(
- generate_step(prompt_tokens, model, **kwargs),
+ for n, (token, _) in zip(
range(max_tokens),
+ generate_step(prompt_tokens, model, **kwargs),
):
if token == tokenizer.eos_token_id:
break
@@ -357,9 +332,9 @@ def generate(
tic = time.perf_counter()
detokenizer.reset()
- for (token, logprobs), n in zip(
- generate_step(prompt_tokens, model, **kwargs),
+ for n, (token, logprobs) in zip(
range(max_tokens),
+ generate_step(prompt_tokens, model, **kwargs),
):
if n == 0:
prompt_time = time.perf_counter() - tic
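
The `cache_history` argument is replaced by a `prompt_cache` created with `make_prompt_cache` and updated in place. A hedged usage sketch (the model repo and prompt text are the ones used in the tests added later in this patch; the token budget of 16 is arbitrary):

```python
from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]

# Reusable per-layer cache; generate_step mutates it in place.
prompt_cache = make_prompt_cache(model)

for _, (token, logprobs) in zip(
    range(16),  # putting range() first stops zip without pulling an extra token
    generate_step(prompt, model, prompt_cache=prompt_cache),
):
    if token == tokenizer.eos_token_id:
        break
```

The argument-order swap in `stream_generate` and `generate` exists for the same reason as the comment in the sketch: with the `range` first, `zip` stops before asking the generator for a token it would otherwise compute and discard.
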
diff --git a/llms/setup.py b/llms/setup.py
index e2cfe0cd..1c696dc0 100644
--- a/llms/setup.py
+++ b/llms/setup.py
@@ -32,6 +32,7 @@ setup(
entry_points={
"console_scripts": [
"mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
+ "mlx_lm.chat = mlx_lm.chat:main",
"mlx_lm.convert = mlx_lm.convert:main",
"mlx_lm.fuse = mlx_lm.fuse:main",
"mlx_lm.generate = mlx_lm.generate:main",
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index cd7e7fd0..1efde5ae 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -1,17 +1,15 @@
# Copyright © 2024 Apple Inc.
-
import unittest
import mlx.core as mx
from mlx.utils import tree_map
-from mlx_lm.models.base import KVCache, RotatingKVCache
-from mlx_lm.utils import make_kv_caches
+from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
class TestModels(unittest.TestCase):
def test_kv_cache(self):
- cache = KVCache(32, 4)
+ cache = KVCache()
k = mx.ones((1, 4, 1, 32), mx.float16)
v = mx.ones((1, 4, 1, 32), mx.float16)
@@ -32,7 +30,7 @@ class TestModels(unittest.TestCase):
def test_rotating_kv_cache(self):
b, h, d = 1, 2, 32
- cache = RotatingKVCache(d, h, max_size=8, step=4)
+ cache = RotatingKVCache(max_size=8, step=4)
k = mx.random.uniform(shape=(b, h, 2, d))
v = mx.random.uniform(shape=(b, h, 2, d))
@@ -65,7 +63,7 @@ class TestModels(unittest.TestCase):
idx %= 8
# Try with nonzero keep
- cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2)
+ cache = RotatingKVCache(max_size=8, step=4, keep=2)
# Check a large update
k = mx.random.uniform(shape=(b, h, 20, d))
@@ -88,6 +86,46 @@ class TestModels(unittest.TestCase):
if idx >= 8:
idx = 2
+ def test_rotating_kv_cache_chat_mode(self):
+ # Test that the rotating kv cache can handle
+ # alternating prompt/prefill with generation
+ d = 4
+ h = 2
+ cache = RotatingKVCache(max_size=18, step=4)
+
+ x = mx.random.uniform(shape=(1, h, 8, d))
+ k, v = cache.update_and_fetch(x, x)
+ self.assertEqual(k.shape[2], 8)
+ self.assertEqual(cache.offset, 8)
+
+ x = mx.random.uniform(shape=(1, h, 1, d))
+ k, v = cache.update_and_fetch(x, x)
+ self.assertEqual(k.shape[2], 9)
+ self.assertEqual(cache.offset, 9)
+ self.assertTrue(mx.allclose(x, k[..., 8:9, :]))
+
+ x = mx.random.uniform(shape=(1, h, 2, d))
+ k, v = cache.update_and_fetch(x, x)
+ self.assertEqual(k.shape[2], 11)
+ self.assertEqual(cache.offset, 11)
+ self.assertTrue(mx.allclose(x, k[..., 9:11, :]))
+
+ x = mx.random.uniform(shape=(1, h, 3, d))
+ k, v = cache.update_and_fetch(x, x)
+ self.assertEqual(k.shape[2], 14)
+ self.assertEqual(cache.offset, 14)
+ self.assertTrue(mx.allclose(x, k[..., 11:14, :]))
+
+ x = mx.random.uniform(shape=(1, h, 6, d))
+ k, v = cache.update_and_fetch(x, x)
+ self.assertEqual(cache.offset, 20)
+ self.assertTrue(mx.allclose(x, k[..., -6:, :]))
+
+ x = mx.random.uniform(shape=(1, h, 2, d))
+ k, v = cache.update_and_fetch(x, x)
+ self.assertEqual(cache.offset, 22)
+ self.assertTrue(mx.allclose(x, k[..., -2:, :]))
+
def model_test_runner(self, model, model_type, vocab_size, num_layers):
self.assertEqual(len(model.layers), num_layers)
@@ -101,7 +139,7 @@ class TestModels(unittest.TestCase):
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
- cache = make_kv_caches(model)
+ cache = make_prompt_cache(model)
outputs = model(inputs, cache)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
@@ -549,6 +587,179 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
+ def test_deepseek(self):
+ from mlx_lm.models import deepseek
+
+ args = deepseek.ModelArgs(
+ model_type="deepseek",
+ vocab_size=1024,
+ hidden_size=128,
+ intermediate_size=256,
+ moe_intermediate_size=256,
+ num_hidden_layers=4,
+ num_attention_heads=8,
+ num_key_value_heads=4,
+ )
+ model = deepseek.Model(args)
+ self.model_test_runner(
+ model, args.model_type, args.vocab_size, args.num_hidden_layers
+ )
+
+ def test_deepseek_v2(self):
+ from mlx_lm.models import deepseek_v2
+
+ args = deepseek_v2.ModelArgs(
+ model_type="deepseek_v2",
+ vocab_size=1024,
+ hidden_size=128,
+ intermediate_size=256,
+ moe_intermediate_size=256,
+ num_hidden_layers=4,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ kv_lora_rank=4,
+ q_lora_rank=4,
+ qk_rope_head_dim=32,
+ v_head_dim=16,
+ qk_nope_head_dim=32,
+ rope_scaling={
+ "beta_fast": 32,
+ "beta_slow": 1,
+ "factor": 40,
+ "mscale": 1.0,
+ "mscale_all_dim": 1.0,
+ "original_max_position_embeddings": 4096,
+ "type": "yarn",
+ },
+ )
+ model = deepseek_v2.Model(args)
+ self.model_test_runner(
+ model, args.model_type, args.vocab_size, args.num_hidden_layers
+ )
+
+ def test_gemma2(self):
+ from mlx_lm.models import gemma2
+
+ args = gemma2.ModelArgs(
+ model_type="gemma2",
+ hidden_size=128,
+ num_hidden_layers=4,
+ intermediate_size=256,
+ num_attention_heads=2,
+ head_dim=32,
+ rms_norm_eps=1e-4,
+ vocab_size=1024,
+ num_key_value_heads=2,
+ )
+ model = gemma2.Model(args)
+ self.model_test_runner(
+ model, args.model_type, args.vocab_size, args.num_hidden_layers
+ )
+
+ def test_gpt_bigcode(self):
+ from mlx_lm.models import gpt_bigcode
+
+ args = gpt_bigcode.ModelArgs(
+ model_type="gpt_bigcode",
+ n_embd=128,
+ n_layer=128,
+ n_inner=256,
+ n_head=4,
+ n_positions=1000,
+ layer_norm_epsilon=1e-5,
+ vocab_size=1024,
+ )
+ model = gpt_bigcode.Model(args)
+ self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)
+
+ def test_nemotron(self):
+ from mlx_lm.models import nemotron
+
+ args = nemotron.ModelArgs(
+ model_type="nemotron",
+ hidden_size=128,
+ hidden_act="gelu",
+ num_hidden_layers=4,
+ intermediate_size=256,
+ num_attention_heads=4,
+ norm_eps=1e-5,
+ vocab_size=1024,
+ num_key_value_heads=2,
+ )
+ model = nemotron.Model(args)
+ self.model_test_runner(
+ model, args.model_type, args.vocab_size, args.num_hidden_layers
+ )
+
+ def test_phi3small(self):
+ from mlx_lm.models import phi3small
+
+ args = phi3small.ModelArgs(
+ model_type="phi3small",
+ hidden_size=128,
+ dense_attention_every_n_layers=2,
+ ff_intermediate_size=256,
+ gegelu_limit=1.0,
+ num_hidden_layers=4,
+ num_attention_heads=4,
+ num_key_value_heads=2,
+ layer_norm_epsilon=1e-4,
+ vocab_size=1000,
+ )
+ model = phi3small.Model(args)
+ self.model_test_runner(
+ model, args.model_type, args.vocab_size, args.num_hidden_layers
+ )
+
+ def test_phimoe(self):
+ from mlx_lm.models import phimoe
+
+ args = phimoe.ModelArgs(
+ model_type="phimoe",
+ vocab_size=320,
+ hidden_size=128,
+ intermediate_size=256,
+ num_hidden_layers=4,
+ num_attention_heads=4,
+ num_key_value_heads=4,
+ rope_scaling={
+ "long_factor": [1.0] * 16,
+ "long_mscale": 1.243163121016122,
+ "original_max_position_embeddings": 4096,
+ "short_factor": [1.0] * 16,
+ "short_mscale": 1.243163121016122,
+ "type": "longrope",
+ },
+ )
+ model = phimoe.Model(args)
+ self.model_test_runner(
+ model, args.model_type, args.vocab_size, args.num_hidden_layers
+ )
+
+ def test_recurrent_gemma(self):
+ from mlx_lm.models import recurrent_gemma
+
+ args = recurrent_gemma.ModelArgs(
+ model_type="recurrent_gemma",
+ hidden_size=128,
+ attention_bias=False,
+ conv1d_width=3,
+ intermediate_size=256,
+ logits_soft_cap=1.0,
+ num_attention_heads=4,
+ num_hidden_layers=4,
+ num_key_value_heads=2,
+ rms_norm_eps=1e-4,
+ rope_theta=1000,
+ attention_window_size=1024,
+ vocab_size=1000,
+ block_types=["recurrent", "recurrent", "attention"],
+ )
+ model = recurrent_gemma.Model(args)
+ self.model_test_runner(
+ model, args.model_type, args.vocab_size, args.num_hidden_layers
+ )
+
if __name__ == "__main__":
unittest.main()
diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py
new file mode 100644
index 00000000..3c1ef49b
--- /dev/null
+++ b/llms/tests/test_prompt_cache.py
@@ -0,0 +1,220 @@
+# Copyright © 2024 Apple Inc.
+
+import os
+import tempfile
+import unittest
+
+import mlx.core as mx
+from mlx_lm.models.cache import (
+ KVCache,
+ MambaCache,
+ RotatingKVCache,
+ load_prompt_cache,
+ make_prompt_cache,
+ save_prompt_cache,
+ trim_prompt_cache,
+)
+from mlx_lm.utils import generate_step, load
+
+HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+
+
+class TestPromptCache(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.test_dir_fid = tempfile.TemporaryDirectory()
+ cls.test_dir = cls.test_dir_fid.name
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.test_dir_fid.cleanup()
+
+ def test_save_load(self):
+ cache = [KVCache() for _ in range(4)]
+ for c in cache:
+ x = mx.random.uniform(shape=(1, 8, 10, 4))
+ c.update_and_fetch(x, x)
+ cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
+ save_prompt_cache(cache_file, cache)
+ loaded_cache = load_prompt_cache(cache_file)
+        self.assertEqual(len(cache), len(loaded_cache))
+ for c, lc in zip(cache, loaded_cache):
+ self.assertEqual(c.offset, lc.offset)
+ self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
+ self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
+
+ # Test with metadata
+ cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
+ metadata = {"a": "b", "c": "d"}
+ save_prompt_cache(cache_file, cache, metadata)
+ _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True)
+ self.assertEqual(metadata, loaded_metadata)
+
+ def test_save_load_rotating_cache(self):
+ cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
+
+ # Test with rotating cache
+ cache = [RotatingKVCache(max_size=8, keep=2) for _ in range(4)]
+ for c in cache:
+ x = mx.random.uniform(shape=(1, 8, 10, 4))
+ c.update_and_fetch(x, x)
+
+ save_prompt_cache(cache_file, cache)
+ loaded_cache = load_prompt_cache(cache_file)
+        self.assertEqual(len(cache), len(loaded_cache))
+ for c, lc in zip(cache, loaded_cache):
+ self.assertEqual(c.offset, lc.offset)
+ self.assertEqual(c.keep, lc.keep)
+ self.assertEqual(c.max_size, lc.max_size)
+ self.assertEqual(c.step, lc.step)
+ self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
+ self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
+
+ # Do a couple single token updates to get a rotation
+ for _ in range(2):
+ for c in cache:
+ x = mx.random.uniform(shape=(1, 8, 1, 4))
+ c.update_and_fetch(x, x)
+
+ save_prompt_cache(cache_file, cache)
+ loaded_cache = load_prompt_cache(cache_file)
+
+ for c, lc in zip(cache, loaded_cache):
+ x = mx.random.uniform(shape=(1, 8, 1, 4))
+ k, v = c.update_and_fetch(x, x)
+ lk, lv = lc.update_and_fetch(x, x)
+ self.assertEqual(c.offset, lc.offset)
+ self.assertTrue(mx.array_equal(k, lk))
+ self.assertTrue(mx.array_equal(v, lv))
+
+ def test_save_load_mixed_cache(self):
+ cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
+
+ cache = [MambaCache(), KVCache(), RotatingKVCache(8), MambaCache()]
+ for c in cache:
+ if isinstance(c, MambaCache):
+ c[0] = mx.random.uniform(shape=(4, 4, 4))
+ c[1] = mx.random.uniform(shape=(4, 4, 4))
+ else:
+ x = mx.random.uniform(shape=(4, 4, 7, 4))
+ y = mx.random.uniform(shape=(4, 4, 7, 4))
+ c.update_and_fetch(x, y)
+
+ save_prompt_cache(cache_file, cache)
+ loaded_cache = load_prompt_cache(cache_file)
+ for c, lc in zip(cache, loaded_cache):
+ if isinstance(c, MambaCache):
+ self.assertTrue(mx.array_equal(c[0], lc[0]))
+ self.assertTrue(mx.array_equal(c[1], lc[1]))
+ else:
+ x = mx.random.uniform(shape=(4, 4, 1, 4))
+ y = mx.random.uniform(shape=(4, 4, 1, 4))
+ k, v = c.update_and_fetch(x, y)
+ lk, lv = lc.update_and_fetch(x, y)
+ self.assertEqual(c.offset, lc.offset)
+ self.assertTrue(mx.array_equal(k, lk))
+ self.assertTrue(mx.array_equal(v, lv))
+
+ def test_cache_with_generate(self):
+ model, tokenizer = load(HF_MODEL_PATH)
+ prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
+ results = zip(range(4), generate_step(prompt, model))
+ toks, all_logits = zip(*(r[1] for r in results))
+
+ prompt_cache = make_prompt_cache(model)
+ i = 0
+ for _, (tok, logits) in zip(
+ range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
+ ):
+ self.assertEqual(tok, toks[i])
+ self.assertTrue(mx.allclose(logits, all_logits[i]))
+ i += 1
+
+ for _, (tok, logits) in zip(
+ range(1),
+ generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
+ ):
+ i += 1
+ self.assertEqual(tok, toks[i])
+ self.assertTrue(mx.allclose(logits, all_logits[i]))
+
+ def test_trim_cache(self):
+ cache = [KVCache() for _ in range(2)]
+ for c in cache:
+ x = mx.random.uniform(shape=(1, 8, 10, 4))
+ c.update_and_fetch(x, x)
+
+ # Trim
+ num_trimmed = trim_prompt_cache(cache, 7)
+ self.assertEqual(num_trimmed, 7)
+
+ # Trim more tokens than remain
+ num_trimmed = trim_prompt_cache(cache, 4)
+ self.assertEqual(num_trimmed, 3)
+
+ # Can't trim mamba cache
+ cache = [MambaCache() for _ in range(2)]
+ for c in cache:
+ c.state = mx.zeros((5, 5))
+ num_trimmed = trim_prompt_cache(cache, 7)
+ self.assertEqual(num_trimmed, 0)
+
+        # All caches have to be trimmable
+ cache = [MambaCache(), KVCache()]
+ cache[0].state = mx.zeros((5, 5))
+ x = mx.random.uniform(shape=(1, 8, 10, 4))
+ cache[1].update_and_fetch(x, x)
+ num_trimmed = trim_prompt_cache(cache, 1)
+ self.assertEqual(num_trimmed, 0)
+
+ cache = [RotatingKVCache(max_size=6) for _ in range(2)]
+ for c in cache:
+ x = mx.random.uniform(shape=(1, 8, 5, 4))
+ c.update_and_fetch(x, x)
+
+ num_trimmed = trim_prompt_cache(cache, 4)
+ self.assertEqual(num_trimmed, 4)
+
+        # Can't trim a rotating KV cache after processing
+        # more than max_size tokens
+ for c in cache:
+ x = mx.random.uniform(shape=(1, 8, 10, 4))
+ c.update_and_fetch(x, x)
+
+ num_trimmed = trim_prompt_cache(cache, 4)
+ self.assertEqual(num_trimmed, 0)
+
+ def test_trim_cache_with_generate(self):
+ model, tokenizer = load(HF_MODEL_PATH)
+ prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
+
+ prompt_cache = make_prompt_cache(model)
+
+ # Generate one token so we process the full prompt
+ last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache))
+ last_tok = mx.array([last_tok])
+
+ # Generate two more tokens
+ results = zip(
+ range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
+ )
+ toks, all_logits = zip(*(r[1] for r in results))
+
+ # To get back to the cache just after processing the prompt,
+ # trim by 3 tokens
+ trim_prompt_cache(prompt_cache, 3)
+
+ # Generate the same thing again
+ results = zip(
+ range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
+ )
+ second_toks, second_all_logits = zip(*(r[1] for r in results))
+ self.assertEqual(toks, second_toks)
+ self.assertTrue(
+ all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
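
For completeness, a hedged sketch of the save/load round trip exercised by `test_save_load` (the file name and metadata values are arbitrary placeholders; metadata values are strings):

```python
import mlx.core as mx

from mlx_lm.models.cache import KVCache, load_prompt_cache, save_prompt_cache

# Build a small cache and round-trip it through a safetensors file.
cache = [KVCache() for _ in range(4)]
for c in cache:
    x = mx.random.uniform(shape=(1, 8, 10, 4))
    c.update_and_fetch(x, x)

save_prompt_cache("prompt_cache.safetensors", cache, {"note": "example"})
loaded_cache, metadata = load_prompt_cache(
    "prompt_cache.safetensors", return_metadata=True
)
assert metadata == {"note": "example"}
```
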
From b7373cb44f2728983ffb99a3d21d61f73230a41e Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 9 Oct 2024 11:09:36 -0700
Subject: [PATCH 044/188] fix long prompt generations (#1023)
---
llms/mlx_lm/utils.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 8649fbe3..cfbcf29e 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -239,8 +239,8 @@ def generate_step(
return y, logprobs.squeeze(0)
while y.size > prefill_step_size:
- model(y[:prefill_step_size][None], cache=cache)
- mx.eval([c.state for c in cache])
+ model(y[:prefill_step_size][None], cache=prompt_cache)
+ mx.eval([c.state for c in prompt_cache])
y = y[prefill_step_size:]
y, logprobs = _step(y)
From 4360e7ccec3d2dd1a2ac96509c74afcaa5e80a95 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 9 Oct 2024 16:48:32 -0700
Subject: [PATCH 045/188] clear cache during prompt processing (#1027)
---
llms/mlx_lm/utils.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index cfbcf29e..1e07546e 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -242,6 +242,7 @@ def generate_step(
model(y[:prefill_step_size][None], cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])
y = y[prefill_step_size:]
+ mx.metal.clear_cache()
y, logprobs = _step(y)
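
Taken together, the two fixes above give the prefill loop its final shape: chunked prompt processing against the shared `prompt_cache`, an explicit eval so each chunk is actually computed, and a Metal cache clear between chunks to bound peak memory. A hedged sketch of that loop as a standalone helper (`prefill` and its signature are illustrative, not part of the library):

```python
import mlx.core as mx


def prefill(model, prompt_cache, y, prefill_step_size=512):
    # Feed the prompt through the model in fixed-size chunks, updating the
    # per-layer prompt_cache in place and forcing evaluation of its state.
    while y.size > prefill_step_size:
        model(y[:prefill_step_size][None], cache=prompt_cache)
        mx.eval([c.state for c in prompt_cache])
        y = y[prefill_step_size:]
        # Release temporary Metal buffers so long prompts don't grow peak memory.
        mx.metal.clear_cache()
    # The remaining (< prefill_step_size) tokens are handled by the decode step.
    return y
```
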
From d72fdeb4eeb0cf65f146ad3c3fbc48a2cd23974b Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Fri, 11 Oct 2024 10:16:20 -0700
Subject: [PATCH 046/188] MusicGen (#1020)
* Add MusicGen model
* add benchmarks
* change to from_pretrained
* symlinks
* add readme and requirements
* fix readme
* readme
---
encodec/README.md | 7 +-
encodec/benchmarks/bench_mx.py | 5 +-
encodec/convert.py | 1 -
encodec/encodec.py | 72 ++++++-
encodec/example.py | 6 +-
encodec/test.py | 11 +-
encodec/utils.py | 77 -------
musicgen/README.md | 31 +++
musicgen/benchmarks/bench_mx.py | 28 +++
musicgen/benchmarks/bench_pt.py | 31 +++
musicgen/encodec.py | 1 +
musicgen/generate.py | 23 ++
musicgen/musicgen.py | 358 ++++++++++++++++++++++++++++++++
musicgen/requirements.txt | 6 +
musicgen/t5.py | 1 +
musicgen/utils.py | 15 ++
t5/README.md | 40 ++--
t5/convert.py | 75 -------
t5/t5.py | 179 +++++++++++-----
19 files changed, 722 insertions(+), 245 deletions(-)
create mode 100644 musicgen/README.md
create mode 100644 musicgen/benchmarks/bench_mx.py
create mode 100644 musicgen/benchmarks/bench_pt.py
create mode 120000 musicgen/encodec.py
create mode 100644 musicgen/generate.py
create mode 100644 musicgen/musicgen.py
create mode 100644 musicgen/requirements.txt
create mode 120000 musicgen/t5.py
create mode 100644 musicgen/utils.py
delete mode 100644 t5/convert.py
diff --git a/encodec/README.md b/encodec/README.md
index 3ab2793c..a3b948bf 100644
--- a/encodec/README.md
+++ b/encodec/README.md
@@ -33,13 +33,14 @@ An example using the model:
```python
import mlx.core as mx
-from utils import load, load_audio, save_audio
+from encodec import EncodecModel
+from utils import load_audio, save_audio
# Load the 48 KHz model and preprocessor.
-model, processor = load("mlx-community/encodec-48khz-float32")
+model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
# Load an audio file
-audio = load_audio("path/to/aduio", model.sampling_rate, model.channels)
+audio = load_audio("path/to/audio", model.sampling_rate, model.channels)
# Preprocess the audio (this can also be a list of arrays for batched
# processing).
diff --git a/encodec/benchmarks/bench_mx.py b/encodec/benchmarks/bench_mx.py
index 2acd4b75..61ddaae8 100644
--- a/encodec/benchmarks/bench_mx.py
+++ b/encodec/benchmarks/bench_mx.py
@@ -3,9 +3,10 @@
import time
import mlx.core as mx
-from utils import load
-model, processor = load("mlx-community/encodec-48khz-float32")
+from encodec import EncodecModel
+
+model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
audio = mx.random.uniform(shape=(288000, 2))
feats, mask = processor(audio)
diff --git a/encodec/convert.py b/encodec/convert.py
index 13bd31a6..5c5f7e22 100644
--- a/encodec/convert.py
+++ b/encodec/convert.py
@@ -10,7 +10,6 @@ from typing import Any, Dict, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten
import encodec
diff --git a/encodec/encodec.py b/encodec/encodec.py
index 3ef47369..4b85dfdd 100644
--- a/encodec/encodec.py
+++ b/encodec/encodec.py
@@ -1,7 +1,10 @@
# Copyright © 2024 Apple Inc.
+import functools
+import json
import math
-from dataclasses import dataclass
+from pathlib import Path
+from types import SimpleNamespace
from typing import List, Optional, Tuple, Union
import mlx.core as mx
@@ -669,3 +672,70 @@ class EncodecModel(nn.Module):
if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
audio_values = audio_values[:, : padding_mask.shape[1]]
return audio_values
+
+ @classmethod
+ def from_pretrained(cls, path_or_repo: str):
+ from huggingface_hub import snapshot_download
+
+ path = Path(path_or_repo)
+ if not path.exists():
+ path = Path(
+ snapshot_download(
+ repo_id=path_or_repo,
+ allow_patterns=["*.json", "*.safetensors", "*.model"],
+ )
+ )
+
+ with open(path / "config.json", "r") as f:
+ config = SimpleNamespace(**json.load(f))
+
+ model = EncodecModel(config)
+ model.load_weights(str(path / "model.safetensors"))
+ processor = functools.partial(
+ preprocess_audio,
+ sampling_rate=config.sampling_rate,
+ chunk_length=model.chunk_length,
+ chunk_stride=model.chunk_stride,
+ )
+ mx.eval(model)
+ return model, processor
+
+
+def preprocess_audio(
+ raw_audio: Union[mx.array, List[mx.array]],
+ sampling_rate: int = 24000,
+ chunk_length: Optional[int] = None,
+ chunk_stride: Optional[int] = None,
+):
+ r"""
+ Prepare inputs for the EnCodec model.
+
+ Args:
+ raw_audio (mx.array or List[mx.array]): The sequence or batch of
+ sequences to be processed.
+ sampling_rate (int): The sampling rate at which the audio waveform
+ should be digitalized.
+ chunk_length (int, optional): The model's chunk length.
+ chunk_stride (int, optional): The model's chunk stride.
+ """
+ if not isinstance(raw_audio, list):
+ raw_audio = [raw_audio]
+
+ raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
+
+ max_length = max(array.shape[0] for array in raw_audio)
+ if chunk_length is not None:
+ max_length += chunk_length - (max_length % chunk_stride)
+
+ inputs = []
+ masks = []
+ for x in raw_audio:
+ length = x.shape[0]
+ mask = mx.ones((length,), dtype=mx.bool_)
+ difference = max_length - length
+ if difference > 0:
+ mask = mx.pad(mask, (0, difference))
+ x = mx.pad(x, ((0, difference), (0, 0)))
+ inputs.append(x)
+ masks.append(mask)
+ return mx.stack(inputs), mx.stack(masks)
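
The relocated `preprocess_audio` backs the processor returned by `from_pretrained`, and it accepts a list of arrays for batched input. A hedged sketch using random placeholders in place of real 48 kHz stereo audio:

```python
import mlx.core as mx

from encodec import EncodecModel

model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")

# Two clips of different lengths; the processor pads them to a common,
# chunk-aligned length and returns a validity mask alongside the inputs.
audio_batch = [
    mx.random.uniform(shape=(48000, 2)),
    mx.random.uniform(shape=(72000, 2)),
]
features, masks = processor(audio_batch)
print(features.shape, masks.shape)
```
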
diff --git a/encodec/example.py b/encodec/example.py
index 97b311a1..15ea476c 100644
--- a/encodec/example.py
+++ b/encodec/example.py
@@ -1,10 +1,12 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
-from utils import load, load_audio, save_audio
+from utils import load_audio, save_audio
+
+from encodec import EncodecModel
# Load the 48 KHz model and preprocessor.
-model, processor = load("mlx-community/encodec-48khz-float32")
+model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
# Load an audio file
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)
diff --git a/encodec/test.py b/encodec/test.py
index ffc23505..ae565c29 100644
--- a/encodec/test.py
+++ b/encodec/test.py
@@ -3,9 +3,10 @@
import mlx.core as mx
import numpy as np
import torch
-from datasets import Audio, load_dataset
-from transformers import AutoProcessor, EncodecModel
-from utils import load, load_audio, preprocess_audio
+from transformers import AutoProcessor
+from transformers import EncodecModel as PTEncodecModel
+
+from encodec import EncodecModel, preprocess_audio
def compare_processors():
@@ -30,8 +31,8 @@ def compare_processors():
def compare_models():
- pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
- mx_model, _ = load("mlx-community/encodec-48khz-float32")
+ pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz")
+ mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
np.random.seed(0)
audio_length = 190560
diff --git a/encodec/utils.py b/encodec/utils.py
index 18b3f063..b429ed83 100644
--- a/encodec/utils.py
+++ b/encodec/utils.py
@@ -1,16 +1,7 @@
# Copyright © 2024 Apple Inc.
-import functools
-import json
-from pathlib import Path
-from types import SimpleNamespace
-from typing import List, Optional, Union
-
import mlx.core as mx
import numpy as np
-from huggingface_hub import snapshot_download
-
-import encodec
def save_audio(file: str, audio: mx.array, sampling_rate: int):
@@ -59,71 +50,3 @@ def load_audio(file: str, sampling_rate: int, channels: int):
out = mx.array(np.frombuffer(out, np.int16))
return out.reshape(-1, channels).astype(mx.float32) / 32767.0
-
-
-def preprocess_audio(
- raw_audio: Union[mx.array, List[mx.array]],
- sampling_rate: int = 24000,
- chunk_length: Optional[int] = None,
- chunk_stride: Optional[int] = None,
-):
- r"""
- Prepare inputs for the EnCodec model.
-
- Args:
- raw_audio (mx.array or List[mx.array]): The sequence or batch of
- sequences to be processed.
- sampling_rate (int): The sampling rate at which the audio waveform
- should be digitalized.
- chunk_length (int, optional): The model's chunk length.
- chunk_stride (int, optional): The model's chunk stride.
- """
- if not isinstance(raw_audio, list):
- raw_audio = [raw_audio]
-
- raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
-
- max_length = max(array.shape[0] for array in raw_audio)
- if chunk_length is not None:
- max_length += chunk_length - (max_length % chunk_stride)
-
- inputs = []
- masks = []
- for x in raw_audio:
- length = x.shape[0]
- mask = mx.ones((length,), dtype=mx.bool_)
- difference = max_length - length
- if difference > 0:
- mask = mx.pad(mask, (0, difference))
- x = mx.pad(x, ((0, difference), (0, 0)))
- inputs.append(x)
- masks.append(mask)
- return mx.stack(inputs), mx.stack(masks)
-
-
-def load(path_or_repo):
- """
- Load the model and audo preprocessor.
- """
- path = Path(path_or_repo)
- if not path.exists():
- path = Path(
- snapshot_download(
- repo_id=path_or_repo,
- allow_patterns=["*.json", "*.safetensors", "*.model"],
- )
- )
-
- with open(path / "config.json", "r") as f:
- config = SimpleNamespace(**json.load(f))
-
- model = encodec.EncodecModel(config)
- model.load_weights(str(path / "model.safetensors"))
- processor = functools.partial(
- preprocess_audio,
- sampling_rate=config.sampling_rate,
- chunk_length=model.chunk_length,
- chunk_stride=model.chunk_stride,
- )
- mx.eval(model)
- return model, processor
diff --git a/musicgen/README.md b/musicgen/README.md
new file mode 100644
index 00000000..baef4fad
--- /dev/null
+++ b/musicgen/README.md
@@ -0,0 +1,31 @@
+# MusicGen
+
+An example of Meta's MusicGen model in MLX.[^1] MusicGen is used to generate
+music from text descriptions.
+
+### Setup
+
+Install the requirements:
+
+```
+pip install -r requirements.txt
+```
+
+### Example
+
+An example using the model:
+
+```python
+import mlx.core as mx
+from musicgen import MusicGen
+from utils import save_audio
+
+model = MusicGen.from_pretrained("facebook/musicgen-medium")
+
+audio = model.generate("happy rock")
+
+save_audio("out.wav", audio, model.sampling_rate)
+```
+
+[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2306.05284) and
+ [code](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md) for more details.
diff --git a/musicgen/benchmarks/bench_mx.py b/musicgen/benchmarks/bench_mx.py
new file mode 100644
index 00000000..b669597b
--- /dev/null
+++ b/musicgen/benchmarks/bench_mx.py
@@ -0,0 +1,28 @@
+# Copyright © 2024 Apple Inc.
+
+import sys
+import time
+from pathlib import Path
+
+import mlx.core as mx
+
+cur_path = Path(__file__).parents[1].resolve()
+sys.path.append(str(cur_path))
+
+from musicgen import MusicGen
+
+text = "folk ballad"
+model = MusicGen.from_pretrained("facebook/musicgen-medium")
+
+max_steps = 100
+
+audio = model.generate(text, max_steps=10)
+mx.eval(audio)
+
+tic = time.time()
+audio = model.generate(text, max_steps=max_steps)
+mx.eval(audio)
+toc = time.time()
+
+ms = 1000 * (toc - tic) / max_steps
+print(f"Time (ms) per step: {ms:.3f}")
diff --git a/musicgen/benchmarks/bench_pt.py b/musicgen/benchmarks/bench_pt.py
new file mode 100644
index 00000000..de01aa66
--- /dev/null
+++ b/musicgen/benchmarks/bench_pt.py
@@ -0,0 +1,31 @@
+# Copyright © 2024 Apple Inc.
+
+import time
+
+import torch
+from transformers import AutoProcessor, MusicgenForConditionalGeneration
+
+model_name = "facebook/musicgen-medium"
+processor = AutoProcessor.from_pretrained(model_name)
+model = MusicgenForConditionalGeneration.from_pretrained(model_name).to("mps")
+
+inputs = processor(
+ text=["folk ballad"],
+ padding=True,
+ return_tensors="pt",
+)
+inputs["input_ids"] = inputs["input_ids"].to("mps")
+inputs["attention_mask"] = inputs["attention_mask"].to("mps")
+
+# warmup
+audio_values = model.generate(**inputs, max_new_tokens=10)
+torch.mps.synchronize()
+
+max_steps = 100
+tic = time.time()
+audio_values = model.generate(**inputs, max_new_tokens=max_steps)
+torch.mps.synchronize()
+toc = time.time()
+
+ms = 1000 * (toc - tic) / max_steps
+print(f"Time (ms) per step: {ms:.3f}")
diff --git a/musicgen/encodec.py b/musicgen/encodec.py
new file mode 120000
index 00000000..8eb278a7
--- /dev/null
+++ b/musicgen/encodec.py
@@ -0,0 +1 @@
+../encodec/encodec.py
\ No newline at end of file
diff --git a/musicgen/generate.py b/musicgen/generate.py
new file mode 100644
index 00000000..5a6b7804
--- /dev/null
+++ b/musicgen/generate.py
@@ -0,0 +1,23 @@
+# Copyright © 2024 Apple Inc.
+
+import argparse
+
+from utils import save_audio
+
+from musicgen import MusicGen
+
+
+def main(text: str, output_path: str, model_name: str, max_steps: int):
+ model = MusicGen.from_pretrained(model_name)
+ audio = model.generate(text, max_steps=max_steps)
+ save_audio(output_path, audio, model.sampling_rate)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", required=False, default="facebook/musicgen-medium")
+ parser.add_argument("--text", required=False, default="happy rock")
+ parser.add_argument("--output-path", required=False, default="0.wav")
+ parser.add_argument("--max-steps", required=False, default=500, type=int)
+ args = parser.parse_args()
+ main(args.text, args.output_path, args.model, args.max_steps)
diff --git a/musicgen/musicgen.py b/musicgen/musicgen.py
new file mode 100644
index 00000000..a2d021a5
--- /dev/null
+++ b/musicgen/musicgen.py
@@ -0,0 +1,358 @@
+# Copyright © 2024 Apple Inc.
+
+import json
+from functools import partial
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from tqdm import tqdm
+
+from encodec import EncodecModel
+from t5 import T5
+
+
+class TextConditioner(nn.Module):
+ def __init__(self, t5_name, input_dim, output_dim):
+ super().__init__()
+ self._t5, self.tokenizer = T5.from_pretrained(t5_name)
+ self.output_proj = nn.Linear(input_dim, output_dim)
+
+ def __call__(self, text):
+ x = self.tokenizer.encode(text)
+ x = self._t5.encode(x)
+ return self.output_proj(x)
+
+
+class KVCache:
+ def __init__(self, head_dim, n_kv_heads):
+ self.n_kv_heads = n_kv_heads
+ if isinstance(head_dim, int):
+ self.k_head_dim = self.v_head_dim = head_dim
+ elif isinstance(head_dim, tuple) and len(head_dim) == 2:
+ self.k_head_dim, self.v_head_dim = head_dim
+ else:
+ raise ValueError("head_dim must be an int or a tuple of two ints")
+ self.keys = None
+ self.values = None
+ self.offset = 0
+ self.step = 256
+
+ def update_and_fetch(self, keys, values):
+ prev = self.offset
+ if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
+ B = keys.shape[0]
+ n_steps = (self.step + keys.shape[2] - 1) // self.step
+ k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
+ v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
+ new_k = mx.zeros(k_shape, keys.dtype)
+ new_v = mx.zeros(v_shape, values.dtype)
+ if self.keys is not None:
+ if prev % self.step != 0:
+ self.keys = self.keys[..., :prev, :]
+ self.values = self.values[..., :prev, :]
+ self.keys = mx.concatenate([self.keys, new_k], axis=2)
+ self.values = mx.concatenate([self.values, new_v], axis=2)
+ else:
+ self.keys, self.values = new_k, new_v
+
+ self.offset += keys.shape[2]
+ self.keys[..., prev : self.offset, :] = keys
+ self.values[..., prev : self.offset, :] = values
+ return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
+
+ @property
+ def state(self):
+ return self.keys, self.values
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, dim, n_heads):
+ super().__init__()
+
+ self.n_heads = n_heads
+
+ head_dim = dim // n_heads
+
+ self.scale = head_dim**-0.5
+
+ self.q_proj = nn.Linear(dim, dim, bias=False)
+ self.k_proj = nn.Linear(dim, dim, bias=False)
+ self.v_proj = nn.Linear(dim, dim, bias=False)
+ self.out_proj = nn.Linear(dim, dim, bias=False)
+
+ def __call__(
+ self,
+ queries: mx.array,
+ keys: mx.array,
+ values: mx.array,
+ mask: Optional[mx.array] = None,
+ cache: Optional[KVCache] = None,
+ ) -> mx.array:
+ B, L_q, D = queries.shape
+ L_k = keys.shape[1]
+
+ queries, keys, values = (
+ self.q_proj(queries),
+ self.k_proj(keys),
+ self.v_proj(values),
+ )
+
+ # Prepare the queries, keys and values for the attention computation
+ queries = queries.reshape(B, L_q, self.n_heads, -1).transpose(0, 2, 1, 3)
+ keys = keys.reshape(B, L_k, self.n_heads, -1).transpose(0, 2, 1, 3)
+ values = values.reshape(B, L_k, self.n_heads, -1).transpose(0, 2, 1, 3)
+
+ if cache is not None:
+ keys, values = cache.update_and_fetch(keys, values)
+
+ output = mx.fast.scaled_dot_product_attention(
+ queries, keys, values, scale=self.scale, mask=mask
+ )
+ output = output.transpose(0, 2, 1, 3).reshape(B, L_q, -1)
+ return self.out_proj(output)
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.num_attention_heads = config.decoder.num_attention_heads
+ self.hidden_size = config.decoder.hidden_size
+ self.self_attn = MultiHeadAttention(self.hidden_size, self.num_attention_heads)
+ self.cross_attn = MultiHeadAttention(self.hidden_size, self.num_attention_heads)
+ self.linear1 = nn.Linear(self.hidden_size, config.decoder.ffn_dim, bias=False)
+ self.linear2 = nn.Linear(config.decoder.ffn_dim, self.hidden_size, bias=False)
+
+ self.norm1 = nn.LayerNorm(self.hidden_size, eps=1e-5)
+ self.norm_cross = nn.LayerNorm(self.hidden_size, eps=1e-5)
+ self.norm2 = nn.LayerNorm(self.hidden_size, eps=1e-5)
+
+ def __call__(
+ self,
+ x: mx.array,
+ conditioning: mx.array,
+ mask: Optional[mx.array] = None,
+ cache: Optional[KVCache] = None,
+ ) -> mx.array:
+ xn = self.norm1(x)
+ x += self.self_attn(xn, xn, xn, mask, cache)
+ xn = self.norm_cross(x)
+ x += self.cross_attn(xn, conditioning, conditioning, mask)
+ xn = self.norm2(x)
+ x += self.linear2(nn.gelu(self.linear1(xn)))
+ return x
+
+
+@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
+def top_k_sampling(
+    logits: mx.array, top_k: int, temperature: float, axis: int = -1
+) -> mx.array:
+ """
+ Apply top-k sampling to logits.
+
+ Args:
+ logits: The logits from the model's output.
+ top_k: Sample from the top k logits.
+ temperature: Temperature parameter for softmax distribution reshaping.
+ axis: Axis along which to sample.
+ Returns:
+ token selected based on the top-k criterion.
+ """
+ # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
+ probs = mx.softmax(logits * (1 / temperature), axis=axis)
+
+ # sort probs in ascending order
+ sorted_indices = mx.argsort(probs, axis=axis)
+ sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis)
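+    # The k-th largest probability (index -top_k after the ascending sort)
+    # is the cutoff below which tokens are zeroed out.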
+ prob_threshold = mx.take(sorted_probs, mx.array(-top_k), axis=axis)
+
+ # select the top K tokens in probability
+ top_probs = mx.where(
+ sorted_probs > prob_threshold,
+ sorted_probs,
+ 0,
+ )
+
+ sorted_token = mx.random.categorical(mx.log(top_probs), axis=axis)
+ token = mx.take_along_axis(
+ sorted_indices, mx.expand_dims(sorted_token, axis), axis=axis
+ )
+
+ return token
+
+
+def create_sin_embedding(positions: mx.array, dim: int, max_period: float = 10000):
+ assert dim % 2 == 0
+ half_dim = dim // 2
+ adim = mx.arange(half_dim).reshape(1, 1, -1)
+ phase = positions / (max_period ** (adim / (half_dim - 1)))
+ return mx.concatenate([mx.cos(phase), mx.sin(phase)], axis=-1)
+
+
+class MusicGen(nn.Module):
+ def __init__(self, config):
+ self.num_codebooks = config.decoder.num_codebooks
+ self.codebook_size = config.audio_encoder.codebook_size
+ self.bos_token_id = config.decoder.bos_token_id
+ self.hidden_size = config.decoder.hidden_size
+ self.num_attention_heads = config.decoder.num_attention_heads
+ self.sampling_rate = config.audio_encoder.sampling_rate
+
+ self.text_conditioner = TextConditioner(
+ config.text_encoder._name_or_path,
+ config.text_encoder.d_model,
+ self.hidden_size,
+ )
+ self.emb = [
+ nn.Embedding(self.codebook_size + 1, self.hidden_size)
+ for _ in range(self.num_codebooks)
+ ]
+ self.layers = [
+ TransformerBlock(config) for _ in range(config.decoder.num_hidden_layers)
+ ]
+ self.out_norm = nn.LayerNorm(self.hidden_size, eps=1e-5)
+ self.linears = [
+ nn.Linear(self.hidden_size, self.codebook_size, bias=False)
+ for _ in range(self.num_codebooks)
+ ]
+ encodec_name = config.audio_encoder._name_or_path.split("/")[-1]
+ encodec_name = encodec_name.replace("_", "-")
+ self._audio_decoder, _ = EncodecModel.from_pretrained(
+ f"mlx-community/{encodec_name}-float32"
+ )
+
+ def __call__(
+ self,
+ audio_tokens: mx.array,
+ conditioning: mx.array,
+        cache: Optional[list[KVCache]] = None,
+ ):
+
+ if cache is None:
+ cache = [None] * len(self.layers)
+
+ x = sum([self.emb[k](audio_tokens[..., k]) for k in range(self.num_codebooks)])
+
+ offset = cache[0].offset if cache[0] is not None else 0
+ pos_emb = create_sin_embedding(offset, self.hidden_size)
+ x += pos_emb.astype(x.dtype)
+
+ for layer, c in zip(self.layers, cache):
+ x = layer(x, conditioning, cache=c)
+
+ x = self.out_norm(x)
+ x = mx.stack([self.linears[k](x) for k in range(self.num_codebooks)], axis=-1)
+ return x
+
+ def generate(
+ self,
+ text: str,
+ max_steps: int = 200,
+ top_k: int = 250,
+ temp: float = 1.0,
+ guidance_coef: float = 3.0,
+ ) -> mx.array:
+ """
+ Generates a waveform conditioned on `text`.
+
+ Args:
+ text (str): The text to condition generation on.
+ max_steps (int): Max steps to generate.
+ top_k (int): Top k used in sampling.
+ temp (float): Sampling softmax temperature.
+            guidance_coef (float): Classifier-free guidance coefficient
+              used to combine the conditional and unconditional logits.
+
+ Returns:
+ An mx.array of audio samples of shape ``(num_samples,)``.
+ """
+        # With no audio prompt, start every codebook from the BOS token
+ audio_shape = (1, max_steps + 1, self.num_codebooks)
+ audio_seq = mx.full(audio_shape, self.bos_token_id)
+
+ text_tokens = self.text_conditioner(text)
+ # Compute conditional and unconditional logits in one batch
+ text_tokens = mx.concatenate([text_tokens, mx.zeros_like(text_tokens)], axis=0)
+
+ head_dim = self.hidden_size // self.num_attention_heads
+ cache = [
+ KVCache(head_dim, self.num_attention_heads) for _ in range(len(self.layers))
+ ]
+ for offset in tqdm(range(max_steps)):
+ audio_input = mx.tile(audio_seq[:, offset : offset + 1], [2, 1, 1])
+ audio_logits = self(audio_input, text_tokens, cache)
+ cond_logits, uncond_logits = audio_logits[:1], audio_logits[1:2]
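+            # Classifier-free guidance: push the logits away from the
+            # unconditional prediction and toward the text-conditional one.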
+ audio_logits = uncond_logits + (cond_logits - uncond_logits) * guidance_coef
+ audio_tokens = top_k_sampling(audio_logits, top_k, temp, axis=-2)
+ # "delay" pattern
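+            # Codebook k lags k steps behind codebook 0, so codebooks that
+            # have not started yet (index > offset) and those whose delayed
+            # stream has already ended are kept at the BOS token.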
+ audio_tokens[..., offset + 1 :] = self.bos_token_id
+ audio_tokens[..., : -max_steps + offset] = self.bos_token_id
+ audio_seq[:, offset + 1 : offset + 2] = audio_tokens
+ mx.eval(audio_seq)
+
+ # Undo delay
+ for i in range(self.num_codebooks):
+ audio_seq[:, : -self.num_codebooks, i] = audio_seq[
+ :, i : -self.num_codebooks + i, i
+ ]
+ audio_seq = audio_seq[:, 1 : -self.num_codebooks + 1]
+
+ audio_seq = mx.swapaxes(audio_seq, -1, -2)[:, mx.newaxis]
+ audio = self._audio_decoder.decode(audio_seq, audio_scales=[None])
+ return audio[0]
+
+ @classmethod
+ def sanitize(cls, weights):
+ out_weights = {}
+ for k, arr in weights.items():
+ if k.startswith("transformer."):
+ k = k[len("transformer.") :]
+
+ if "cross_attention" in k:
+ k = k.replace("cross_attention", "cross_attn")
+
+ if "condition_provider" in k:
+ k = k.replace(
+ "condition_provider.conditioners.description", "text_conditioner"
+ )
+
+ if "in_proj_weight" in k:
+ dim = arr.shape[0] // 3
+ name = "in_proj_weight"
+ out_weights[k.replace(name, "q_proj.weight")] = arr[:dim]
+ out_weights[k.replace(name, "k_proj.weight")] = arr[dim : dim * 2]
+ out_weights[k.replace(name, "v_proj.weight")] = arr[dim * 2 :]
+ continue
+
+ out_weights[k] = arr
+ return out_weights
+
+ @classmethod
+ def from_pretrained(cls, path_or_repo: str):
+ import torch
+ from huggingface_hub import snapshot_download
+
+ path = Path(path_or_repo)
+ if not path.exists():
+ path = Path(
+ snapshot_download(
+ repo_id=path_or_repo,
+ allow_patterns=["*.json", "state_dict.bin"],
+ )
+ )
+
+ with open(path / "config.json", "r") as f:
+ config = SimpleNamespace(**json.load(f))
+ config.text_encoder = SimpleNamespace(**config.text_encoder)
+ config.audio_encoder = SimpleNamespace(**config.audio_encoder)
+ config.decoder = SimpleNamespace(**config.decoder)
+
+ weights = torch.load(path / "state_dict.bin", weights_only=True)["best_state"]
+ weights = {k: mx.array(v) for k, v in weights.items()}
+ weights = cls.sanitize(weights)
+
+ model = MusicGen(config)
+ model.load_weights(list(weights.items()))
+ return model
diff --git a/musicgen/requirements.txt b/musicgen/requirements.txt
new file mode 100644
index 00000000..5c716fe3
--- /dev/null
+++ b/musicgen/requirements.txt
@@ -0,0 +1,6 @@
+mlx>=0.18
+numpy
+huggingface_hub
+torch
+transformers
+scipy
diff --git a/musicgen/t5.py b/musicgen/t5.py
new file mode 120000
index 00000000..f31e26f9
--- /dev/null
+++ b/musicgen/t5.py
@@ -0,0 +1 @@
+../t5/t5.py
\ No newline at end of file
diff --git a/musicgen/utils.py b/musicgen/utils.py
new file mode 100644
index 00000000..78e92571
--- /dev/null
+++ b/musicgen/utils.py
@@ -0,0 +1,15 @@
+# Copyright © 2024 Apple Inc.
+
+import mlx.core as mx
+import numpy as np
+
+
+def save_audio(file: str, audio: mx.array, sampling_rate: int):
+ """
+ Save audio to a wave (.wav) file.
+ """
+ from scipy.io.wavfile import write
+
+ audio = mx.clip(audio, -1, 1)
+ audio = (audio * 32767).astype(mx.int16)
+ write(file, sampling_rate, np.array(audio))
diff --git a/t5/README.md b/t5/README.md
index a0cc861b..e5165f8f 100644
--- a/t5/README.md
+++ b/t5/README.md
@@ -7,31 +7,6 @@ tasks by prepending task-specific prefixes to the input, e.g.:
This example also supports the FLAN-T5 models variants.[^2]
-## Setup
-
-Download and convert the model:
-
-```sh
-python convert.py --model
-```
-
-This will make the `.npz` file which MLX can read.
-
-The `` can be any of the following:
-
-| Model Name | Model Size |
-| ---------- | ----------
-| t5-small | 60 million |
-| t5-base | 220 million |
-| t5-large | 770 million |
-| t5-3b | 3 billion |
-| t5-11b | 11 billion |
-
-The FLAN variants can be specified with `google/flan-t5-small`,
-`google/flan-t5-base`, etc. See the [Hugging Face
-page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
-complete list of models.
-
## Generate
Generate text with:
@@ -48,6 +23,21 @@ To see a list of options run:
python t5.py --help
```
+The `<model>` can be any of the following:
+
+| Model Name | Model Size |
+| ---------- | ---------- |
+| t5-small | 60 million |
+| t5-base | 220 million |
+| t5-large | 770 million |
+| t5-3b | 3 billion |
+| t5-11b | 11 billion |
+
+The FLAN variants can be specified with `google/flan-t5-small`,
+`google/flan-t5-base`, etc. See the [Hugging Face
+page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
+complete list of models.
+
[^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683)
or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5).
[^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416).
diff --git a/t5/convert.py b/t5/convert.py
deleted file mode 100644
index e2108a0c..00000000
--- a/t5/convert.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import numpy as np
-from transformers import T5ForConditionalGeneration
-
-SHARED_REPLACEMENT_PATTERNS = [
- (".block.", ".layers."),
- (".k.", ".key_proj."),
- (".o.", ".out_proj."),
- (".q.", ".query_proj."),
- (".v.", ".value_proj."),
- ("shared.", "wte."),
- ("lm_head.", "lm_head.linear."),
- (".layer.0.layer_norm.", ".ln1."),
- (".layer.1.layer_norm.", ".ln2."),
- (".layer.2.layer_norm.", ".ln3."),
- (".final_layer_norm.", ".ln."),
- (
- "layers.0.layer.0.SelfAttention.relative_attention_bias.",
- "relative_attention_bias.embeddings.",
- ),
-]
-
-ENCODER_REPLACEMENT_PATTERNS = [
- (".layer.0.SelfAttention.", ".attention."),
- (".layer.1.DenseReluDense.", ".dense."),
-]
-
-DECODER_REPLACEMENT_PATTERNS = [
- (".layer.0.SelfAttention.", ".self_attention."),
- (".layer.1.EncDecAttention.", ".cross_attention."),
- (".layer.2.DenseReluDense.", ".dense."),
-]
-
-
-def replace_key(key: str) -> str:
- for old, new in SHARED_REPLACEMENT_PATTERNS:
- key = key.replace(old, new)
- if key.startswith("encoder."):
- for old, new in ENCODER_REPLACEMENT_PATTERNS:
- key = key.replace(old, new)
- elif key.startswith("decoder."):
- for old, new in DECODER_REPLACEMENT_PATTERNS:
- key = key.replace(old, new)
- return key
-
-
-def convert(model_name, dtype):
- dtype = getattr(np, dtype)
- model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
- weights = {
- replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
- }
- file_name = model_name.replace("/", "-")
- print(f"Saving weights to {file_name}.npz")
- np.savez(f"{file_name}.npz", **weights)
-
-
-if __name__ == "__main__":
- import argparse
-
- parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
- parser.add_argument(
- "--model",
- type=str,
- help="Name of the T5 model.",
- default="t5-small",
- )
- parser.add_argument(
- "--dtype",
- help="The model data type.",
- type=str,
- choices=["float16", "float32"],
- default="float32",
- )
- args = parser.parse_args()
- convert(args.model, args.dtype)
diff --git a/t5/t5.py b/t5/t5.py
index 89f2e486..04a0da8c 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -1,12 +1,45 @@
import argparse
+import json
+from pathlib import Path
from time import perf_counter_ns
+from types import SimpleNamespace
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
-from mlx.utils import tree_map, tree_unflatten
-from transformers import AutoTokenizer, T5Config
+from transformers import AutoTokenizer
+
+
+class Tokenizer:
+ def __init__(self, config, model_name):
+ self._decoder_start_id = config.decoder_start_token_id
+ self._tokenizer = AutoTokenizer.from_pretrained(
+ model_name,
+ legacy=False,
+ model_max_length=getattr(config, "n_positions", 512),
+ )
+
+ @property
+ def eos_id(self) -> int:
+ return self._tokenizer.eos_token_id
+
+ @property
+ def decoder_start_id(self) -> int:
+ return self._decoder_start_id
+
+ def encode(self, s: str) -> mx.array:
+ return mx.array(
+ self._tokenizer(
+ s,
+ return_tensors="np",
+ return_attention_mask=False,
+ )["input_ids"]
+ )
+
+ def decode(self, t: List[int], with_sep: bool = True) -> str:
+ tokens = self._tokenizer.convert_ids_to_tokens(t)
+ return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
def _relative_position_bucket(
@@ -60,10 +93,10 @@ def _relative_position_bucket(
class RelativePositionBias(nn.Module):
- def __init__(self, config: T5Config, bidirectional: bool):
+ def __init__(self, config, bidirectional: bool):
self.bidirectional = bidirectional
self.num_buckets = config.relative_attention_num_buckets
- self.max_distance = config.relative_attention_max_distance
+ self.max_distance = getattr(config, "relative_attention_max_distance", 128)
self.n_heads = config.num_heads
self.embeddings = nn.Embedding(
config.relative_attention_num_buckets, config.num_heads
@@ -91,7 +124,7 @@ class RelativePositionBias(nn.Module):
class MultiHeadAttention(nn.Module):
- def __init__(self, config: T5Config):
+ def __init__(self, config):
super().__init__()
inner_dim = config.d_kv * config.num_heads
self.num_heads = config.num_heads
@@ -135,17 +168,21 @@ class MultiHeadAttention(nn.Module):
class DenseActivation(nn.Module):
- def __init__(self, config: T5Config):
+ def __init__(self, config):
super().__init__()
mlp_dims = config.d_ff or config.d_model * 4
- self.gated = config.feed_forward_proj.startswith("gated")
+ self.gated = hasattr(config, "feed_forward_proj")
+ activation = (
+ "relu"
+ if not self.gated
+ else config.feed_forward_proj.removeprefix("gated-")
+ )
if self.gated:
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
else:
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
- activation = config.feed_forward_proj.removeprefix("gated-")
if activation == "relu":
self.act = nn.relu
elif activation == "gelu":
@@ -166,7 +203,7 @@ class DenseActivation(nn.Module):
class TransformerEncoderLayer(nn.Module):
- def __init__(self, config: T5Config):
+ def __init__(self, config):
super().__init__()
self.attention = MultiHeadAttention(config)
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
@@ -184,7 +221,7 @@ class TransformerEncoderLayer(nn.Module):
class TransformerEncoder(nn.Module):
- def __init__(self, config: T5Config):
+ def __init__(self, config):
super().__init__()
self.layers = [
TransformerEncoderLayer(config) for i in range(config.num_layers)
@@ -200,7 +237,7 @@ class TransformerEncoder(nn.Module):
class TransformerDecoderLayer(nn.Module):
- def __init__(self, config: T5Config):
+ def __init__(self, config):
super().__init__()
self.self_attention = MultiHeadAttention(config)
self.cross_attention = MultiHeadAttention(config)
@@ -233,7 +270,7 @@ class TransformerDecoderLayer(nn.Module):
class TransformerDecoder(nn.Module):
- def __init__(self, config: T5Config):
+ def __init__(self, config):
super().__init__()
n_layers = getattr(config, "num_decoder_layers", config.num_layers)
self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
@@ -262,7 +299,7 @@ class TransformerDecoder(nn.Module):
class OutputHead(nn.Module):
- def __init__(self, config: T5Config):
+ def __init__(self, config):
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
def __call__(self, inputs):
@@ -270,11 +307,11 @@ class OutputHead(nn.Module):
class T5(nn.Module):
- def __init__(self, config: T5Config):
+ def __init__(self, config):
self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.encoder = TransformerEncoder(config)
self.decoder = TransformerDecoder(config)
- self.tie_word_embeddings = config.tie_word_embeddings
+ self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
if not self.tie_word_embeddings:
self.lm_head = OutputHead(config)
self.model_dim = config.d_model
@@ -313,36 +350,82 @@ class T5(nn.Module):
):
return self.decode(decoder_inputs, self.encode(inputs))[0]
+ @classmethod
+ def sanitize(cls, weights):
+ shared_replacement_patterns = [
+ (".block.", ".layers."),
+ (".k.", ".key_proj."),
+ (".o.", ".out_proj."),
+ (".q.", ".query_proj."),
+ (".v.", ".value_proj."),
+ ("shared.", "wte."),
+ ("lm_head.", "lm_head.linear."),
+ (".layer.0.layer_norm.", ".ln1."),
+ (".layer.1.layer_norm.", ".ln2."),
+ (".layer.2.layer_norm.", ".ln3."),
+ (".final_layer_norm.", ".ln."),
+ (
+ "layers.0.layer.0.SelfAttention.relative_attention_bias.",
+ "relative_attention_bias.embeddings.",
+ ),
+ ]
-class Tokenizer:
- def __init__(self, config: T5Config):
- self._decoder_start_id = config.decoder_start_token_id
- self._tokenizer = AutoTokenizer.from_pretrained(
- args.model,
- legacy=False,
- model_max_length=getattr(config, "n_positions", 512),
- )
+ encoder_replacement_patterns = [
+ (".layer.0.SelfAttention.", ".attention."),
+ (".layer.1.DenseReluDense.", ".dense."),
+ ]
- @property
- def eos_id(self) -> int:
- return self._tokenizer.eos_token_id
+ decoder_replacement_patterns = [
+ (".layer.0.SelfAttention.", ".self_attention."),
+ (".layer.1.EncDecAttention.", ".cross_attention."),
+ (".layer.2.DenseReluDense.", ".dense."),
+ ]
- @property
- def decoder_start_id(self) -> int:
- return self._decoder_start_id
+ ignored_keys = [
+ "decoder.layers.0.cross_attention.relative_attention_bias.weight"
+ ]
- def encode(self, s: str) -> mx.array:
- return mx.array(
- self._tokenizer(
- s,
- return_tensors="np",
- return_attention_mask=False,
- )["input_ids"]
- )
+ def replace_key(key: str) -> str:
+ for old, new in shared_replacement_patterns:
+ key = key.replace(old, new)
+ if key.startswith("encoder."):
+ for old, new in encoder_replacement_patterns:
+ key = key.replace(old, new)
+ elif key.startswith("decoder."):
+ for old, new in decoder_replacement_patterns:
+ key = key.replace(old, new)
+ return key
- def decode(self, t: List[int], with_sep: bool = True) -> str:
- tokens = self._tokenizer.convert_ids_to_tokens(t)
- return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
+ weights = {replace_key(k): v for k, v in weights.items()}
+ for key in ignored_keys:
+ if key in weights:
+ del weights[key]
+ return weights
+
+ @classmethod
+ def from_pretrained(
+ cls, path_or_repo: str, dtype: mx.Dtype = mx.bfloat16
+ ) -> tuple["T5", Tokenizer]:
+ from huggingface_hub import snapshot_download
+
+ path = Path(path_or_repo)
+ if not path.exists():
+ path = Path(
+ snapshot_download(
+ repo_id=path_or_repo,
+ allow_patterns=["*.json", "*.safetensors", "*.model"],
+ )
+ )
+
+ with open(path / "config.json", "r") as f:
+ config = SimpleNamespace(**json.load(f))
+
+ model = T5(config)
+ weights = mx.load(str(path / "model.safetensors"))
+ weights = cls.sanitize(weights)
+ weights = {k: v.astype(dtype) for k, v in weights.items()}
+ model.load_weights(list(weights.items()))
+ return model, Tokenizer(config, "t5-base")
def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
@@ -363,19 +446,6 @@ def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float]
yield y.squeeze()
-def load_model(model_name: str, dtype: str = "float16"):
- config = T5Config.from_pretrained(args.model)
- dtype = getattr(mx, dtype)
- model = T5(config)
- file_name = model_name.replace("/", "-")
- weights = mx.load(f"{file_name}.npz")
- weights = tree_unflatten(list(weights.items()))
- weights = tree_map(lambda p: p.astype(dtype), weights)
- model.update(weights)
- mx.eval(model.parameters())
- return model, Tokenizer(config)
-
-
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="T5 Inference script")
parser.add_argument(
@@ -421,7 +491,8 @@ if __name__ == "__main__":
mx.random.seed(args.seed)
- model, tokenizer = load_model(args.model, args.dtype)
+ dtype = getattr(mx, args.dtype)
+ model, tokenizer = T5.from_pretrained(args.model, dtype)
if args.encode_only:
print("[INFO] Encoding with T5...", flush=True)
From a5f2bab070c96d21905b5c6ae564c2fbd64990b2 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Fri, 11 Oct 2024 21:17:41 -0700
Subject: [PATCH 047/188] Add FLUX finetuning (#1028)
---
flux/README.md | 185 ++++++++++++
flux/dreambooth.py | 378 ++++++++++++++++++++++++
flux/flux/__init__.py | 248 ++++++++++++++++
flux/flux/autoencoder.py | 357 ++++++++++++++++++++++
flux/flux/clip.py | 154 ++++++++++
flux/flux/layers.py | 302 +++++++++++++++++++
flux/flux/lora.py | 76 +++++
flux/flux/model.py | 134 +++++++++
flux/flux/sampler.py | 56 ++++
flux/flux/t5.py | 244 +++++++++++++++
flux/flux/tokenizers.py | 185 ++++++++++++
flux/flux/utils.py | 209 +++++++++++++
flux/requirements.txt | 7 +
flux/static/dog-r4-g8-1200-512x1024.png | Bin 0 -> 772523 bytes
flux/static/dog-r4-g8-1200.png | Bin 0 -> 433610 bytes
flux/static/dog6.png | Bin 0 -> 444359 bytes
flux/static/generated-mlx.png | Bin 0 -> 156875 bytes
flux/txt2image.py | 150 ++++++++++
18 files changed, 2685 insertions(+)
create mode 100644 flux/README.md
create mode 100644 flux/dreambooth.py
create mode 100644 flux/flux/__init__.py
create mode 100644 flux/flux/autoencoder.py
create mode 100644 flux/flux/clip.py
create mode 100644 flux/flux/layers.py
create mode 100644 flux/flux/lora.py
create mode 100644 flux/flux/model.py
create mode 100644 flux/flux/sampler.py
create mode 100644 flux/flux/t5.py
create mode 100644 flux/flux/tokenizers.py
create mode 100644 flux/flux/utils.py
create mode 100644 flux/requirements.txt
create mode 100644 flux/static/dog-r4-g8-1200-512x1024.png
create mode 100644 flux/static/dog-r4-g8-1200.png
create mode 100644 flux/static/dog6.png
create mode 100644 flux/static/generated-mlx.png
create mode 100644 flux/txt2image.py
diff --git a/flux/README.md b/flux/README.md
new file mode 100644
index 00000000..33bebfd6
--- /dev/null
+++ b/flux/README.md
@@ -0,0 +1,185 @@
+FLUX
+====
+
+FLUX implementation in MLX. The implementation is ported directly from
+[https://github.com/black-forest-labs/flux](https://github.com/black-forest-labs/flux)
+and the model weights are downloaded directly from the Hugging Face Hub.
+
+The goal of this example is to be clean and educational, and to allow for
+experimentation with finetuning FLUX models as well as adding extra
+functionality such as in-/outpainting, guidance with custom losses, etc.
+
+![Generated image](static/generated-mlx.png)
+*Image generated using FLUX-dev in MLX and the prompt 'An image in the style of
+tron emanating futuristic technology with the word "MLX" in the center with
+capital red letters.'*
+
+Installation
+------------
+
+The dependencies are minimal, namely:
+
+- `huggingface-hub` to download the checkpoints.
+- `regex` for the tokenization.
+- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script.
+- `sentencepiece` for the T5 tokenizer.
+
+You can install all of the above with the `requirements.txt` as follows:
+
+ pip install -r requirements.txt
+
+Inference
+---------
+
+Inference in this example is similar to the stable diffusion example. The
+class to get you started is `FluxPipeline` from the `flux` module.
+
+```python
+import mlx.core as mx
+from flux import FluxPipeline
+
+# This will download all the weights from HF hub
+flux = FluxPipeline("flux-schnell")
+
+# Make a generator that returns the latent variables from the reverse diffusion
+# process
+latent_generator = flux.generate_latents(
+ "A photo of an astronaut riding a horse on Mars",
+ num_steps=4,
+ latent_size=(32, 64), # 256x512 image
+)
+
+# The first return value of the generator contains the conditioning and the
+# random noise at the beginning of the diffusion process.
+conditioning = next(latent_generator)
+(
+ x_T, # The initial noise
+ x_positions, # The integer positions used for image positional encoding
+ t5_conditioning, # The T5 features from the text prompt
+ t5_positions, # Integer positions for text (normally all 0s)
+ clip_conditioning, # The clip text features from the text prompt
+) = conditioning
+
+# Returning the conditioning as the first output from the generator allows us
+# to unload T5 and clip before running the diffusion transformer.
+mx.eval(conditioning)
+
+# Evaluate each diffusion step
+for x_t in latent_generator:
+ mx.eval(x_t)
+
+# Note that we need to pass the latent size because it is collapsed and
+# patchified in x_t and we need to unwrap it.
+img = flux.decode(x_t, latent_size=(32, 64))
+```
+
+The above is essentially the implementation of the `txt2image.py` script,
+except for some additional logic to quantize the model and/or load trained
+adapters. One can use the script as follows:
+
+```shell
+python txt2image.py --n-images 4 --n-rows 2 --image-size 256x512 'A photo of an astronaut riding a horse on Mars.'
+```
+
+### Experimental Options
+
+FLUX pads the prompt to a fixed length of 512 tokens for the dev model and
+256 for the schnell model. Skipping the padding results in faster generation,
+but it is not clear how it may affect the generated images. To skip the
+padding in this example, pass `--no-t5-padding` to the `txt2image.py` script
+or instantiate the pipeline with `FluxPipeline("flux-schnell", t5_padding=False)`.
+
+Finetuning
+----------
+
+The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell,
+but your mileage may vary) on a provided image dataset. The dataset folder
+must have an `index.json` file with the following format:
+
+```json
+{
+ "data": [
+ {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
+ {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
+ {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
+ ...
+ ]
+}
+```
+
+The training script by default trains for 600 iterations with a batch size of
+1, gradient accumulation of 4, and a LoRA rank of 8. Run `python dreambooth.py
+--help` for the list of hyperparameters you can tune.
+
+> [!Note]
+> FLUX finetuning requires approximately 50GB of RAM. QLoRA is coming soon and
+> should reduce this number significantly.
+
+### Training Example
+
+This is a step-by-step finetuning example. We will be using the data from
+[https://github.com/google/dreambooth](https://github.com/google/dreambooth).
+In particular, we will use `dog6`, which is a popular example for showcasing
+DreamBooth [^1].
+
+The training images are the following 5 images [^2]:
+
+![dog6](static/dog6.png)
+
+We start by making the following `index.json` file and placing it in the same
+folder as the images.
+
+```json
+{
+ "data": [
+ {"image": "00.jpg", "text": "A photo of sks dog"},
+ {"image": "01.jpg", "text": "A photo of sks dog"},
+ {"image": "02.jpg", "text": "A photo of sks dog"},
+ {"image": "03.jpg", "text": "A photo of sks dog"},
+ {"image": "04.jpg", "text": "A photo of sks dog"}
+ ]
+}
+```
+
+Subsequently we finetune FLUX using the following command:
+
+```shell
+python dreambooth.py \
+ --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
+ --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
+ --lora-rank 4 --grad-accumulate 8 \
+ path/to/dreambooth/dataset/dog6
+```
+
+The training requires approximately 50GB of RAM and on an M2 Ultra it takes a
+bit more than 1 hour.
+
+### Using the Adapter
+
+The adapters are saved in `mlx_output` and can be used directly by the
+`txt2image.py` script. For instance,
+
+```shell
+python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
+    --adapter mlx_output/0001200_adapters.safetensors \
+ --fuse-adapter \
+ --no-t5-padding \
+ 'A photo of an sks dog lying on the sand at a beach in Greece'
+```
+
+generates an image that looks like the following,
+
+![dog image](static/dog-r4-g8-1200.png)
+
+and of course we can pass `--image-size 512x1024` to get larger images with
+different aspect ratios,
+
+![dog image](static/dog-r4-g8-1200-512x1024.png)
+
+The arguments that are relevant to the adapters are of course `--adapter` and
+`--fuse-adapter`. The first defines the path to an adapter to apply to the
+model and the second fuses the adapter back into the model to get a bit more
+speed during generation.
+
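+For reference, below is a minimal sketch of applying an adapter from Python
+using the classes in this example. The adapter path and the LoRA rank of 4
+are simply the values from the training command above; adjust them to match
+your own run.
+
+```python
+import mlx.core as mx
+
+from flux import FluxPipeline
+
+flux = FluxPipeline("flux-dev")
+
+# Recreate the LoRA layers with the rank used during training so that the
+# saved adapter weights line up with the model.
+flux.linear_to_lora_layers(rank=4)
+
+# Load the adapter saved by dreambooth.py and fuse it back into the base
+# weights for a bit more speed during generation.
+adapter = list(mx.load("mlx_output/0001200_adapters.safetensors").items())
+flux.flow.load_weights(adapter, strict=False)
+flux.fuse_lora_layers()
+
+images = flux.generate_images(
+    "A photo of an sks dog lying on the sand at a beach in Greece",
+    n_images=1,
+)
+```
+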
+[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
+[^2]: The images are from Unsplash by <https://unsplash.com/@alvannee>.
diff --git a/flux/dreambooth.py b/flux/dreambooth.py
new file mode 100644
index 00000000..4a4dbb08
--- /dev/null
+++ b/flux/dreambooth.py
@@ -0,0 +1,378 @@
+# Copyright © 2024 Apple Inc.
+
+import argparse
+import json
+import time
+from functools import partial
+from pathlib import Path
+
+import mlx.core as mx
+import mlx.nn as nn
+import mlx.optimizers as optim
+import numpy as np
+from mlx.nn.utils import average_gradients
+from mlx.utils import tree_flatten, tree_map, tree_reduce
+from PIL import Image
+from tqdm import tqdm
+
+from flux import FluxPipeline
+
+
+class FinetuningDataset:
+ def __init__(self, flux, args):
+ self.args = args
+ self.flux = flux
+ self.dataset_base = Path(args.dataset)
+ dataset_index = self.dataset_base / "index.json"
+ if not dataset_index.exists():
+ raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
+ with open(dataset_index, "r") as f:
+ self.index = json.load(f)
+
+ self.latents = []
+ self.t5_features = []
+ self.clip_features = []
+
+ def _random_crop_resize(self, img):
+ resolution = self.args.resolution
+ width, height = img.size
+
+ a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
+
+        # Randomly crop the input image to between 0.8 and 1.0 of its original dimensions
+ crop_size = (
+ max((0.8 + 0.2 * a) * width, resolution[0]),
+ max((0.8 + 0.2 * a) * height, resolution[1]),
+ )
+ pan = (width - crop_size[0], height - crop_size[1])
+ img = img.crop(
+ (
+ pan[0] * b,
+ pan[1] * c,
+ crop_size[0] + pan[0] * b,
+ crop_size[1] + pan[1] * c,
+ )
+ )
+
+        # Fit the largest rectangle with the aspect ratio of the target
+        # resolution inside the cropped image.
+ width, height = crop_size
+ ratio = resolution[0] / resolution[1]
+ r1 = (height * ratio, height)
+ r2 = (width, width / ratio)
+ r = r1 if r1[0] <= width else r2
+ img = img.crop(
+ (
+ (width - r[0]) / 2,
+ (height - r[1]) / 2,
+ (width + r[0]) / 2,
+ (height + r[1]) / 2,
+ )
+ )
+
+        # Finally, resize the image to the target resolution
+ img = img.resize(resolution, Image.LANCZOS)
+
+ return mx.array(np.array(img))
+
+ def encode_images(self):
+ """Encode the images in the latent space to prepare for training."""
+ self.flux.ae.eval()
+ for sample in tqdm(self.index["data"]):
+ input_img = Image.open(self.dataset_base / sample["image"])
+ for i in range(self.args.num_augmentations):
+ img = self._random_crop_resize(input_img)
+ img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
+ x_0 = self.flux.ae.encode(img[None])
+ x_0 = x_0.astype(self.flux.dtype)
+ mx.eval(x_0)
+ self.latents.append(x_0)
+
+ def encode_prompts(self):
+ """Pre-encode the prompts so that we don't recompute them during
+ training (doesn't allow finetuning the text encoders)."""
+ for sample in tqdm(self.index["data"]):
+ t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
+ t5_feat = self.flux.t5(t5_tok)
+ clip_feat = self.flux.clip(clip_tok).pooled_output
+ mx.eval(t5_feat, clip_feat)
+ self.t5_features.append(t5_feat)
+ self.clip_features.append(clip_feat)
+
+ def iterate(self, batch_size):
+ xs = mx.concatenate(self.latents)
+ t5 = mx.concatenate(self.t5_features)
+ clip = mx.concatenate(self.clip_features)
+ mx.eval(xs, t5, clip)
+ n_aug = self.args.num_augmentations
+ while True:
+ x_indices = mx.random.permutation(len(self.latents))
+ c_indices = x_indices // n_aug
+ for i in range(0, len(self.latents), batch_size):
+ x_i = x_indices[i : i + batch_size]
+ c_i = c_indices[i : i + batch_size]
+ yield xs[x_i], t5[c_i], clip[c_i]
+
+
+def generate_progress_images(iteration, flux, args):
+ """Generate images to monitor the progress of the finetuning."""
+ out_dir = Path(args.output_dir)
+ out_dir.mkdir(parents=True, exist_ok=True)
+ out_file = out_dir / f"{iteration:07d}_progress.png"
+ print(f"Generating {str(out_file)}", flush=True)
+
+ # Generate some images and arrange them in a grid
+ n_rows = 2
+ n_images = 4
+ x = flux.generate_images(
+ args.progress_prompt,
+ n_images,
+ args.progress_steps,
+ )
+ x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
+ B, H, W, C = x.shape
+ x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
+ x = x.reshape(n_rows * H, B // n_rows * W, C)
+ x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
+ x = (x * 255).astype(mx.uint8)
+
+    # Save them to disk
+ im = Image.fromarray(np.array(x))
+ im.save(out_file)
+
+
+def save_adapters(iteration, flux, args):
+ out_dir = Path(args.output_dir)
+ out_dir.mkdir(parents=True, exist_ok=True)
+ out_file = out_dir / f"{iteration:07d}_adapters.safetensors"
+ print(f"Saving {str(out_file)}")
+
+ mx.save_safetensors(
+ str(out_file),
+ dict(tree_flatten(flux.flow.trainable_parameters())),
+ metadata={
+ "lora_rank": str(args.lora_rank),
+ "lora_blocks": str(args.lora_blocks),
+ },
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Finetune Flux to generate images with a specific subject"
+ )
+
+ parser.add_argument(
+ "--model",
+ default="dev",
+ choices=[
+ "dev",
+ "schnell",
+ ],
+ help="Which flux model to train",
+ )
+ parser.add_argument(
+ "--guidance", type=float, default=4.0, help="The guidance factor to use."
+ )
+ parser.add_argument(
+ "--iterations",
+ type=int,
+ default=600,
+ help="How many iterations to train for",
+ )
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=1,
+        help="The batch size to use when finetuning the flux model",
+ )
+ parser.add_argument(
+ "--resolution",
+ type=lambda x: tuple(map(int, x.split("x"))),
+ default=(512, 512),
+ help="The resolution of the training images",
+ )
+ parser.add_argument(
+ "--num-augmentations",
+ type=int,
+ default=5,
+ help="Augment the images by random cropping and panning",
+ )
+ parser.add_argument(
+ "--progress-prompt",
+ required=True,
+ help="Use this prompt when generating images for evaluation",
+ )
+ parser.add_argument(
+ "--progress-steps",
+ type=int,
+ default=50,
+ help="Use this many steps when generating images for evaluation",
+ )
+ parser.add_argument(
+ "--progress-every",
+ type=int,
+ default=50,
+ help="Generate images every PROGRESS_EVERY steps",
+ )
+ parser.add_argument(
+ "--checkpoint-every",
+ type=int,
+ default=50,
+ help="Save the model every CHECKPOINT_EVERY steps",
+ )
+ parser.add_argument(
+ "--lora-blocks",
+ type=int,
+ default=-1,
+ help="Train the last LORA_BLOCKS transformer blocks",
+ )
+ parser.add_argument(
+ "--lora-rank", type=int, default=8, help="LoRA rank for finetuning"
+ )
+ parser.add_argument(
+ "--warmup-steps", type=int, default=100, help="Learning rate warmup"
+ )
+ parser.add_argument(
+ "--learning-rate", type=float, default="1e-4", help="Learning rate for training"
+ )
+ parser.add_argument(
+ "--grad-accumulate",
+ type=int,
+ default=4,
+ help="Accumulate gradients for that many iterations before applying them",
+ )
+ parser.add_argument(
+ "--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
+ )
+
+ parser.add_argument("dataset")
+
+ args = parser.parse_args()
+
+ # Load the model and set it up for LoRA training. We use the same random
+ # state when creating the LoRA layers so all workers will have the same
+ # initial weights.
+ mx.random.seed(0x0F0F0F0F)
+ flux = FluxPipeline("flux-" + args.model)
+ flux.flow.freeze()
+ flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
+
+ # Reset the seed to a different seed per worker if we are in distributed
+ # mode so that each worker is working on different data, diffusion step and
+ # random noise.
+ mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
+
+ # Report how many parameters we are training
+ trainable_params = tree_reduce(
+ lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
+ )
+ print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True)
+
+ # Set up the optimizer and training steps. The steps are a bit verbose to
+ # support gradient accumulation together with compilation.
+ warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
+ cosine = optim.cosine_decay(
+ args.learning_rate, args.iterations // args.grad_accumulate
+ )
+ lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps])
+ optimizer = optim.Adam(learning_rate=lr_schedule)
+ state = [flux.flow.state, optimizer.state, mx.random.state]
+
+ @partial(mx.compile, inputs=state, outputs=state)
+ def single_step(x, t5_feat, clip_feat, guidance):
+ loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
+ x, t5_feat, clip_feat, guidance
+ )
+ grads = average_gradients(grads)
+ optimizer.update(flux.flow, grads)
+
+ return loss
+
+ @partial(mx.compile, inputs=state, outputs=state)
+ def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
+ return nn.value_and_grad(flux.flow, flux.training_loss)(
+ x, t5_feat, clip_feat, guidance
+ )
+
+ @partial(mx.compile, inputs=state, outputs=state)
+ def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
+ loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
+ x, t5_feat, clip_feat, guidance
+ )
+ grads = tree_map(lambda a, b: a + b, prev_grads, grads)
+ return loss, grads
+
+ @partial(mx.compile, inputs=state, outputs=state)
+ def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
+ loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
+ x, t5_feat, clip_feat, guidance
+ )
+ grads = tree_map(
+ lambda a, b: (a + b) / args.grad_accumulate,
+ prev_grads,
+ grads,
+ )
+ grads = average_gradients(grads)
+ optimizer.update(flux.flow, grads)
+
+ return loss
+
+ # We simply route to the appropriate step based on whether we have
+ # gradients from a previous step and whether we should be performing an
+ # update or simply computing and accumulating gradients in this step.
+ def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
+ if prev_grads is None:
+ if perform_step:
+ return single_step(x, t5_feat, clip_feat, guidance), None
+ else:
+ return compute_loss_and_grads(x, t5_feat, clip_feat, guidance)
+ else:
+ if perform_step:
+ return (
+ grad_accumulate_and_step(
+ x, t5_feat, clip_feat, guidance, prev_grads
+ ),
+ None,
+ )
+ else:
+ return compute_loss_and_accumulate_grads(
+ x, t5_feat, clip_feat, guidance, prev_grads
+ )
+
+    print("Creating the training dataset.", flush=True)
+ dataset = FinetuningDataset(flux, args)
+ dataset.encode_images()
+ dataset.encode_prompts()
+ guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
+
+ # An initial generation to compare
+ generate_progress_images(0, flux, args)
+
+ grads = None
+ losses = []
+ tic = time.time()
+ for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)):
+ loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
+ mx.eval(loss, grads, state)
+ losses.append(loss.item())
+
+ if (i + 1) % 10 == 0:
+ toc = time.time()
+ peak_mem = mx.metal.get_peak_memory() / 1024**3
+ print(
+ f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
+ f"It/s: {10 / (toc - tic):.3f} "
+ f"Peak mem: {peak_mem:.3f} GB",
+ flush=True,
+ )
+
+ if (i + 1) % args.progress_every == 0:
+ generate_progress_images(i + 1, flux, args)
+
+ if (i + 1) % args.checkpoint_every == 0:
+ save_adapters(i + 1, flux, args)
+
+ if (i + 1) % 10 == 0:
+ losses = []
+ tic = time.time()
diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py
new file mode 100644
index 00000000..8d39d605
--- /dev/null
+++ b/flux/flux/__init__.py
@@ -0,0 +1,248 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+import time
+from typing import Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_unflatten
+from tqdm import tqdm
+
+from .lora import LoRALinear
+from .sampler import FluxSampler
+from .utils import (
+ load_ae,
+ load_clip,
+ load_clip_tokenizer,
+ load_flow_model,
+ load_t5,
+ load_t5_tokenizer,
+)
+
+
+class FluxPipeline:
+ def __init__(self, name: str, t5_padding: bool = True):
+ self.dtype = mx.bfloat16
+ self.name = name
+ self.t5_padding = t5_padding
+
+ self.ae = load_ae(name)
+ self.flow = load_flow_model(name)
+ self.clip = load_clip(name)
+ self.clip_tokenizer = load_clip_tokenizer(name)
+ self.t5 = load_t5(name)
+ self.t5_tokenizer = load_t5_tokenizer(name)
+ self.sampler = FluxSampler(name)
+
+ def ensure_models_are_loaded(self):
+ mx.eval(
+ self.ae.parameters(),
+ self.flow.parameters(),
+ self.clip.parameters(),
+ self.t5.parameters(),
+ )
+
+ def reload_text_encoders(self):
+ self.t5 = load_t5(self.name)
+ self.clip = load_clip(self.name)
+
+ def tokenize(self, text):
+ t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
+ clip_tokens = self.clip_tokenizer.encode(text)
+ return t5_tokens, clip_tokens
+
+ def _prepare_latent_images(self, x):
+ b, h, w, c = x.shape
+
+ # Pack the latent image to 2x2 patches
+ x = x.reshape(b, h // 2, 2, w // 2, 2, c)
+ x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
+
+ # Create positions ids used to positionally encode each patch. Due to
+ # the way RoPE works, this results in an interesting positional
+ # encoding where parts of the feature are holding different positional
+ # information. Namely, the first part holds information independent of
+ # the spatial position (hence 0s), the 2nd part holds vertical spatial
+ # information and the last one horizontal.
+ i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
+ j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
+ x_ids = mx.stack([i, j, k], axis=-1)
+ x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
+
+ return x, x_ids
+
+ def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
+ # Prepare the text features
+ txt = self.t5(t5_tokens)
+ if len(txt) == 1 and n_images > 1:
+ txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
+ txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
+
+ # Prepare the clip text features
+ vec = self.clip(clip_tokens).pooled_output
+ if len(vec) == 1 and n_images > 1:
+ vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
+
+ return txt, txt_ids, vec
+
+ def _denoising_loop(
+ self,
+ x_t,
+ x_ids,
+ txt,
+ txt_ids,
+ vec,
+ num_steps: int = 35,
+ guidance: float = 4.0,
+ start: float = 1,
+ stop: float = 0,
+ ):
+ B = len(x_t)
+
+ def scalar(x):
+ return mx.full((B,), x, dtype=self.dtype)
+
+ guidance = scalar(guidance)
+ timesteps = self.sampler.timesteps(
+ num_steps,
+ x_t.shape[1],
+ start=start,
+ stop=stop,
+ )
+ for i in range(num_steps):
+ t = timesteps[i]
+ t_prev = timesteps[i + 1]
+
+ pred = self.flow(
+ img=x_t,
+ img_ids=x_ids,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=scalar(t),
+ guidance=guidance,
+ )
+ x_t = self.sampler.step(pred, x_t, t, t_prev)
+
+ yield x_t
+
+ def generate_latents(
+ self,
+ text: str,
+ n_images: int = 1,
+ num_steps: int = 35,
+ guidance: float = 4.0,
+ latent_size: Tuple[int, int] = (64, 64),
+ seed=None,
+ ):
+ # Set the PRNG state
+ if seed is not None:
+ mx.random.seed(seed)
+
+ # Create the latent variables
+ x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
+ x_T, x_ids = self._prepare_latent_images(x_T)
+
+ # Get the conditioning
+ t5_tokens, clip_tokens = self.tokenize(text)
+ txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
+
+ # Yield the conditioning for controlled evaluation by the caller
+ yield (x_T, x_ids, txt, txt_ids, vec)
+
+ # Yield the latent sequences from the denoising loop
+ yield from self._denoising_loop(
+ x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
+ )
+
+ def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
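+        # Unpack the flattened 2x2 patches from _prepare_latent_images back
+        # into a (B, h, w, C) latent image before running the VAE decoder.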
+ h, w = latent_size
+ x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
+ x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
+ x = self.ae.decode(x)
+ return mx.clip(x + 1, 0, 2) * 0.5
+
+ def generate_images(
+ self,
+ text: str,
+ n_images: int = 1,
+ num_steps: int = 35,
+ guidance: float = 4.0,
+ latent_size: Tuple[int, int] = (64, 64),
+ seed=None,
+ reload_text_encoders: bool = True,
+ progress: bool = True,
+ ):
+ latents = self.generate_latents(
+ text, n_images, num_steps, guidance, latent_size, seed
+ )
+ mx.eval(next(latents))
+
+ if reload_text_encoders:
+ self.reload_text_encoders()
+
+ for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
+ mx.eval(x_t)
+
+ images = []
+ for i in tqdm(range(len(x_t)), disable=not progress):
+ images.append(self.decode(x_t[i : i + 1]))
+ mx.eval(images[-1])
+ images = mx.concatenate(images, axis=0)
+ mx.eval(images)
+
+ return images
+
+ def training_loss(
+ self,
+ x_0: mx.array,
+ t5_features: mx.array,
+ clip_features: mx.array,
+ guidance: mx.array,
+ ):
+ # Get the text conditioning
+ txt = t5_features
+ txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
+ vec = clip_features
+
+ # Prepare the latent input
+ x_0, x_ids = self._prepare_latent_images(x_0)
+
+ # Forward process
+ t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
+ eps = mx.random.normal(x_0.shape, dtype=self.dtype)
+ x_t = self.sampler.add_noise(x_0, t, noise=eps)
+ x_t = mx.stop_gradient(x_t)
+
+ # Do the denoising
+ pred = self.flow(
+ img=x_t,
+ img_ids=x_ids,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t,
+ guidance=guidance,
+ )
+
+ return (pred + x_0 - eps).square().mean()
+
+ def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
+ """Swap the linear layers in the transformer blocks with LoRA layers."""
+ all_blocks = self.flow.double_blocks + self.flow.single_blocks
+ all_blocks.reverse()
+ num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
+ for i, block in zip(range(num_blocks), all_blocks):
+ loras = []
+ for name, module in block.named_modules():
+ if isinstance(module, nn.Linear):
+ loras.append((name, LoRALinear.from_base(module, r=rank)))
+ block.update_modules(tree_unflatten(loras))
+
+ def fuse_lora_layers(self):
+ fused_layers = []
+ for name, module in self.flow.named_modules():
+ if isinstance(module, LoRALinear):
+ fused_layers.append((name, module.fuse()))
+ self.flow.update_modules(tree_unflatten(fused_layers))
diff --git a/flux/flux/autoencoder.py b/flux/flux/autoencoder.py
new file mode 100644
index 00000000..6332bb57
--- /dev/null
+++ b/flux/flux/autoencoder.py
@@ -0,0 +1,357 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import List
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.nn.layers.upsample import upsample_nearest
+
+
+@dataclass
+class AutoEncoderParams:
+ resolution: int
+ in_channels: int
+ ch: int
+ out_ch: int
+ ch_mult: List[int]
+ num_res_blocks: int
+ z_channels: int
+ scale_factor: float
+ shift_factor: float
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = nn.GroupNorm(
+ num_groups=32,
+ dims=in_channels,
+ eps=1e-6,
+ affine=True,
+ pytorch_compatible=True,
+ )
+ self.q = nn.Linear(in_channels, in_channels)
+ self.k = nn.Linear(in_channels, in_channels)
+ self.v = nn.Linear(in_channels, in_channels)
+ self.proj_out = nn.Linear(in_channels, in_channels)
+
+ def __call__(self, x: mx.array) -> mx.array:
+ B, H, W, C = x.shape
+
+ y = x.reshape(B, 1, -1, C)
+ y = self.norm(y)
+ q = self.q(y)
+ k = self.k(y)
+ v = self.v(y)
+ y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5))
+ y = self.proj_out(y)
+
+ return x + y.reshape(B, H, W, C)
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+
+ self.norm1 = nn.GroupNorm(
+ num_groups=32,
+ dims=in_channels,
+ eps=1e-6,
+ affine=True,
+ pytorch_compatible=True,
+ )
+ self.conv1 = nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ self.norm2 = nn.GroupNorm(
+ num_groups=32,
+ dims=out_channels,
+ eps=1e-6,
+ affine=True,
+ pytorch_compatible=True,
+ )
+ self.conv2 = nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if self.in_channels != self.out_channels:
+ self.nin_shortcut = nn.Linear(in_channels, out_channels)
+
+ def __call__(self, x):
+ h = x
+ h = self.norm1(h)
+ h = nn.silu(h)
+ h = self.conv1(h)
+
+ h = self.norm2(h)
+ h = nn.silu(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
+ )
+
+ def __call__(self, x: mx.array):
+ x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
+ x = self.conv(x)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def __call__(self, x: mx.array):
+ x = upsample_nearest(x, (2, 2))
+ x = self.conv(x)
+ return x
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ resolution: int,
+ in_channels: int,
+ ch: int,
+ ch_mult: list[int],
+ num_res_blocks: int,
+ z_channels: int,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ # downsampling
+ self.conv_in = nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = []
+ block_in = self.ch
+ for i_level in range(self.num_resolutions):
+ block = []
+ attn = [] # TODO: Remove the attn, nobody appends anything to it
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+ block_in = block_out
+ down = {}
+ down["block"] = block
+ down["attn"] = attn
+ if i_level != self.num_resolutions - 1:
+ down["downsample"] = Downsample(block_in)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = {}
+ self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+ self.mid["attn_1"] = AttnBlock(block_in)
+ self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+ # end
+ self.norm_out = nn.GroupNorm(
+ num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
+ )
+ self.conv_out = nn.Conv2d(
+ block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def __call__(self, x: mx.array):
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level]["block"][i_block](hs[-1])
+
+ # TODO: Remove the attn
+ if len(self.down[i_level]["attn"]) > 0:
+ h = self.down[i_level]["attn"][i_block](h)
+
+ hs.append(h)
+
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level]["downsample"](hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid["block_1"](h)
+ h = self.mid["attn_1"](h)
+ h = self.mid["block_2"](h)
+
+ # end
+ h = self.norm_out(h)
+ h = nn.silu(h)
+ h = self.conv_out(h)
+
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ ch: int,
+ out_ch: int,
+ ch_mult: list[int],
+ num_res_blocks: int,
+ in_channels: int,
+ resolution: int,
+ z_channels: int,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.ffactor = 2 ** (self.num_resolutions - 1)
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+
+ # z to block_in
+ self.conv_in = nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
+ )
+
+ # middle
+ self.mid = {}
+ self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+ self.mid["attn_1"] = AttnBlock(block_in)
+ self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+ # upsampling
+ self.up = []
+ for i_level in reversed(range(self.num_resolutions)):
+ block = []
+ attn = [] # TODO: Remove the attn, nobody appends anything to it
+
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks + 1):
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+ block_in = block_out
+ up = {}
+ up["block"] = block
+ up["attn"] = attn
+ if i_level != 0:
+ up["upsample"] = Upsample(block_in)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = nn.GroupNorm(
+ num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
+ )
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def __call__(self, z: mx.array):
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid["block_1"](h)
+ h = self.mid["attn_1"](h)
+ h = self.mid["block_2"](h)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level]["block"][i_block](h)
+
+ # TODO: Remove the attn
+ if len(self.up[i_level]["attn"]) > 0:
+ h = self.up[i_level]["attn"][i_block](h)
+
+ if i_level != 0:
+ h = self.up[i_level]["upsample"](h)
+
+ # end
+ h = self.norm_out(h)
+ h = nn.silu(h)
+ h = self.conv_out(h)
+
+ return h
+
+
+class DiagonalGaussian(nn.Module):
+ def __call__(self, z: mx.array):
+ mean, logvar = mx.split(z, 2, axis=-1)
+ if self.training:
+ std = mx.exp(0.5 * logvar)
+ eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
+ return mean + std * eps
+ else:
+ return mean
+
+
+class AutoEncoder(nn.Module):
+ def __init__(self, params: AutoEncoderParams):
+ super().__init__()
+ self.encoder = Encoder(
+ resolution=params.resolution,
+ in_channels=params.in_channels,
+ ch=params.ch,
+ ch_mult=params.ch_mult,
+ num_res_blocks=params.num_res_blocks,
+ z_channels=params.z_channels,
+ )
+ self.decoder = Decoder(
+ resolution=params.resolution,
+ in_channels=params.in_channels,
+ ch=params.ch,
+ out_ch=params.out_ch,
+ ch_mult=params.ch_mult,
+ num_res_blocks=params.num_res_blocks,
+ z_channels=params.z_channels,
+ )
+ self.reg = DiagonalGaussian()
+
+ self.scale_factor = params.scale_factor
+ self.shift_factor = params.shift_factor
+
+ def sanitize(self, weights):
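+        # PyTorch conv weights are (out, in, kH, kW); MLX expects (out, kH, kW, in).
+        # The flatten/reshape round-trip materializes the transposed data, and 1x1
+        # kernels have their singleton spatial dims squeezed away.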
+ new_weights = {}
+ for k, w in weights.items():
+ if w.ndim == 4:
+ w = w.transpose(0, 2, 3, 1)
+ w = w.reshape(-1).reshape(w.shape)
+ if w.shape[1:3] == (1, 1):
+ w = w.squeeze((1, 2))
+ new_weights[k] = w
+ return new_weights
+
+ def encode(self, x: mx.array):
+ z = self.reg(self.encoder(x))
+ z = self.scale_factor * (z - self.shift_factor)
+ return z
+
+ def decode(self, z: mx.array):
+ z = z / self.scale_factor + self.shift_factor
+ return self.decoder(z)
+
+ def __call__(self, x: mx.array):
+ return self.decode(self.encode(x))
diff --git a/flux/flux/clip.py b/flux/flux/clip.py
new file mode 100644
index 00000000..d5a30dbf
--- /dev/null
+++ b/flux/flux/clip.py
@@ -0,0 +1,154 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
+
+
+@dataclass
+class CLIPTextModelConfig:
+ num_layers: int = 23
+ model_dims: int = 1024
+ num_heads: int = 16
+ max_length: int = 77
+ vocab_size: int = 49408
+ hidden_act: str = "quick_gelu"
+
+ @classmethod
+ def from_dict(cls, config):
+ return cls(
+ num_layers=config["num_hidden_layers"],
+ model_dims=config["hidden_size"],
+ num_heads=config["num_attention_heads"],
+ max_length=config["max_position_embeddings"],
+ vocab_size=config["vocab_size"],
+ hidden_act=config["hidden_act"],
+ )
+
+
+@dataclass
+class CLIPOutput:
+ # The last_hidden_state indexed at the EOS token and possibly projected if
+ # the model has a projection layer
+ pooled_output: Optional[mx.array] = None
+
+ # The full sequence output of the transformer after the final layernorm
+ last_hidden_state: Optional[mx.array] = None
+
+ # A list of hidden states corresponding to the outputs of the transformer layers
+ hidden_states: Optional[List[mx.array]] = None
+
+
+class CLIPEncoderLayer(nn.Module):
+ """The transformer encoder layer from CLIP."""
+
+ def __init__(self, model_dims: int, num_heads: int, activation: str):
+ super().__init__()
+
+ self.layer_norm1 = nn.LayerNorm(model_dims)
+ self.layer_norm2 = nn.LayerNorm(model_dims)
+
+ self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)
+
+ self.linear1 = nn.Linear(model_dims, 4 * model_dims)
+ self.linear2 = nn.Linear(4 * model_dims, model_dims)
+
+ self.act = _ACTIVATIONS[activation]
+
+ def __call__(self, x, attn_mask=None):
+ y = self.layer_norm1(x)
+ y = self.attention(y, y, y, attn_mask)
+ x = y + x
+
+ y = self.layer_norm2(x)
+ y = self.linear1(y)
+ y = self.act(y)
+ y = self.linear2(y)
+ x = y + x
+
+ return x
+
+
+class CLIPTextModel(nn.Module):
+ """Implements the text encoder transformer from CLIP."""
+
+ def __init__(self, config: CLIPTextModelConfig):
+ super().__init__()
+
+ self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
+ self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
+ self.layers = [
+ CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
+ for i in range(config.num_layers)
+ ]
+ self.final_layer_norm = nn.LayerNorm(config.model_dims)
+
+ def _get_mask(self, N, dtype):
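+        # Additive causal mask: future positions receive a large negative bias
+        # (kept finite for float16) so each token attends only to itself and
+        # earlier tokens.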
+ indices = mx.arange(N)
+ mask = indices[:, None] < indices[None]
+ mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
+ return mask
+
+ def sanitize(self, weights):
+ new_weights = {}
+ for key, w in weights.items():
+ # Remove prefixes
+ if key.startswith("text_model."):
+ key = key[11:]
+ if key.startswith("embeddings."):
+ key = key[11:]
+ if key.startswith("encoder."):
+ key = key[8:]
+
+ # Map attention layers
+ if "self_attn." in key:
+ key = key.replace("self_attn.", "attention.")
+ if "q_proj." in key:
+ key = key.replace("q_proj.", "query_proj.")
+ if "k_proj." in key:
+ key = key.replace("k_proj.", "key_proj.")
+ if "v_proj." in key:
+ key = key.replace("v_proj.", "value_proj.")
+
+ # Map ffn layers
+ if "mlp.fc1" in key:
+ key = key.replace("mlp.fc1", "linear1")
+ if "mlp.fc2" in key:
+ key = key.replace("mlp.fc2", "linear2")
+
+ new_weights[key] = w
+
+ return new_weights
+
+ def __call__(self, x):
+ # Extract some shapes
+ B, N = x.shape
+ eos_tokens = x.argmax(-1)
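+        # In the CLIP vocab the EOS token has the largest id, so an argmax over
+        # the raw token ids gives its position for the pooled output.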
+
+ # Compute the embeddings
+ x = self.token_embedding(x)
+ x = x + self.position_embedding.weight[:N]
+
+ # Compute the features from the transformer
+ mask = self._get_mask(N, x.dtype)
+ hidden_states = []
+ for l in self.layers:
+ x = l(x, mask)
+ hidden_states.append(x)
+
+ # Apply the final layernorm and return
+ x = self.final_layer_norm(x)
+ last_hidden_state = x
+
+ # Select the EOS token
+ pooled_output = x[mx.arange(len(x)), eos_tokens]
+
+ return CLIPOutput(
+ pooled_output=pooled_output,
+ last_hidden_state=last_hidden_state,
+ hidden_states=hidden_states,
+ )
diff --git a/flux/flux/layers.py b/flux/flux/layers.py
new file mode 100644
index 00000000..12397904
--- /dev/null
+++ b/flux/flux/layers.py
@@ -0,0 +1,302 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from dataclasses import dataclass
+from functools import partial
+from typing import List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+def _rope(pos: mx.array, dim: int, theta: float):
+ scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
+ omega = 1.0 / (theta**scale)
+ x = pos[..., None] * omega
+ cosx = mx.cos(x)
+ sinx = mx.sin(x)
+ pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1)
+ pe = pe.reshape(*pe.shape[:-1], 2, 2)
+
+ return pe
+
+
+@partial(mx.compile, shapeless=True)
+def _ab_plus_cd(a, b, c, d):
+ return a * b + c * d
+
+
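+# Rotate each consecutive pair of channels in x by the 2x2 rotation matrices
+# produced by _rope, broadcasting over batch, heads and sequence positions.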
+def _apply_rope(x, pe):
+ s = x.shape
+ x = x.reshape(*s[:-1], -1, 1, 2)
+ x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1])
+ return x.reshape(s)
+
+
+def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
+ B, H, L, D = q.shape
+
+ q = _apply_rope(q, pe)
+ k = _apply_rope(k, pe)
+ x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5))
+
+ return x.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+
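+# Standard sinusoidal timestep embedding: scale t by time_factor, project onto
+# dim // 2 geometrically spaced frequencies and concatenate cos and sin features.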
+def timestep_embedding(
+ t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
+):
+ half = dim // 2
+ freqs = mx.arange(0, half, dtype=mx.float32) / half
+ freqs = freqs * (-math.log(max_period))
+ freqs = mx.exp(freqs)
+
+ x = (time_factor * t)[:, None] * freqs[None]
+ x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)
+
+ return x.astype(t.dtype)
+
+
+class EmbedND(nn.Module):
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
+ super().__init__()
+
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def __call__(self, ids: mx.array):
+ n_axes = ids.shape[-1]
+ pe = mx.concatenate(
+ [_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+ axis=-3,
+ )
+
+ return pe[:, None]
+
+
+class MLPEmbedder(nn.Module):
+ def __init__(self, in_dim: int, hidden_dim: int):
+ super().__init__()
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+ def __call__(self, x: mx.array) -> mx.array:
+ return self.out_layer(nn.silu(self.in_layer(x)))
+
+
+class QKNorm(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.query_norm = nn.RMSNorm(dim)
+ self.key_norm = nn.RMSNorm(dim)
+
+ def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
+ return self.query_norm(q), self.key_norm(k)
+
+
+class SelfAttention(nn.Module):
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.norm = QKNorm(head_dim)
+ self.proj = nn.Linear(dim, dim)
+
+ def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
+ H = self.num_heads
+ B, L, _ = x.shape
+ qkv = self.qkv(x)
+ q, k, v = mx.split(qkv, 3, axis=-1)
+ q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ q, k = self.norm(q, k)
+ x = _attention(q, k, v, pe)
+ x = self.proj(x)
+ return x
+
+
+@dataclass
+class ModulationOut:
+ shift: mx.array
+ scale: mx.array
+ gate: mx.array
+
+
+class Modulation(nn.Module):
+ def __init__(self, dim: int, double: bool):
+ super().__init__()
+ self.is_double = double
+ self.multiplier = 6 if double else 3
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
+
+ def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
+ x = self.lin(nn.silu(x))
+ xs = mx.split(x[:, None, :], self.multiplier, axis=-1)
+
+ mod1 = ModulationOut(*xs[:3])
+ mod2 = ModulationOut(*xs[3:]) if self.is_double else None
+
+ return mod1, mod2
+
+
+class DoubleStreamBlock(nn.Module):
+ def __init__(
+ self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
+ ):
+ super().__init__()
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.img_mod = Modulation(hidden_size, double=True)
+ self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+ self.img_attn = SelfAttention(
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+ )
+
+ self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+ self.img_mlp = nn.Sequential(
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+ nn.GELU(approx="tanh"),
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+ )
+
+ self.txt_mod = Modulation(hidden_size, double=True)
+ self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+ self.txt_attn = SelfAttention(
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+ )
+
+ self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+ self.txt_mlp = nn.Sequential(
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+ nn.GELU(approx="tanh"),
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+ )
+
+ def __call__(
+ self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
+ ) -> Tuple[mx.array, mx.array]:
+ B, L, _ = img.shape
+ _, S, _ = txt.shape
+ H = self.num_heads
+
+ img_mod1, img_mod2 = self.img_mod(vec)
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
+
+ # prepare image for attention
+ img_modulated = self.img_norm1(img)
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+ img_qkv = self.img_attn.qkv(img_modulated)
+ img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
+ img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ img_q, img_k = self.img_attn.norm(img_q, img_k)
+
+ # prepare txt for attention
+ txt_modulated = self.txt_norm1(txt)
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
+ txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
+ txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+ txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+ txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
+
+ # run actual attention
+ q = mx.concatenate([txt_q, img_q], axis=2)
+ k = mx.concatenate([txt_k, img_k], axis=2)
+ v = mx.concatenate([txt_v, img_v], axis=2)
+
+ attn = _attention(q, k, v, pe)
+ txt_attn, img_attn = mx.split(attn, [S], axis=1)
+
+        # calculate the img blocks
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+ img = img + img_mod2.gate * self.img_mlp(
+ (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
+ )
+
+        # calculate the txt blocks
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
+ txt = txt + txt_mod2.gate * self.txt_mlp(
+ (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
+ )
+
+ return img, txt
+
+
+class SingleStreamBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qk_scale: Optional[float] = None,
+ ):
+ super().__init__()
+ self.hidden_dim = hidden_size
+ self.num_heads = num_heads
+ head_dim = hidden_size // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ # qkv and mlp_in
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+ # proj and mlp_out
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+ self.norm = QKNorm(head_dim)
+
+ self.hidden_size = hidden_size
+ self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+
+ self.mlp_act = nn.GELU(approx="tanh")
+ self.modulation = Modulation(hidden_size, double=False)
+
+ def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
+ B, L, _ = x.shape
+ H = self.num_heads
+
+ mod, _ = self.modulation(vec)
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+
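+        # linear1 fuses the qkv projection with the MLP input projection, so one
+        # split yields q, k, v (hidden_size each) plus the mlp stream.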
+ q, k, v, mlp = mx.split(
+ self.linear1(x_mod),
+ [self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
+ axis=-1,
+ )
+ q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+ q, k = self.norm(q, k)
+
+ # compute attention
+ y = _attention(q, k, v, pe)
+
+ # compute activation in mlp stream, cat again and run second linear layer
+ y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
+ return x + mod.gate * y
+
+
+class LastLayer(nn.Module):
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+ self.linear = nn.Linear(
+ hidden_size, patch_size * patch_size * out_channels, bias=True
+ )
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+ )
+
+ def __call__(self, x: mx.array, vec: mx.array):
+ shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1)
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+ x = self.linear(x)
+ return x
diff --git a/flux/flux/lora.py b/flux/flux/lora.py
new file mode 100644
index 00000000..b0c8ae56
--- /dev/null
+++ b/flux/flux/lora.py
@@ -0,0 +1,76 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class LoRALinear(nn.Module):
+ @staticmethod
+ def from_base(
+ linear: nn.Linear,
+ r: int = 8,
+ dropout: float = 0.0,
+ scale: float = 1.0,
+ ):
+ output_dims, input_dims = linear.weight.shape
+ lora_lin = LoRALinear(
+ input_dims=input_dims,
+ output_dims=output_dims,
+ r=r,
+ dropout=dropout,
+ scale=scale,
+ )
+ lora_lin.linear = linear
+ return lora_lin
+
+ def fuse(self):
+ linear = self.linear
+ bias = "bias" in linear
+ weight = linear.weight
+ dtype = weight.dtype
+
+ output_dims, input_dims = weight.shape
+ fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
+
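+        # Fold the low-rank update into the dense weight:
+        # W' = W + scale * lora_b.T @ lora_a.T, matching y = x @ lora_a @ lora_b
+        # in __call__.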
+ lora_b = self.scale * self.lora_b.T
+ lora_a = self.lora_a.T
+ fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype)
+ if bias:
+ fused_linear.bias = linear.bias
+
+ return fused_linear
+
+ def __init__(
+ self,
+ input_dims: int,
+ output_dims: int,
+ r: int = 8,
+ dropout: float = 0.0,
+ scale: float = 1.0,
+ bias: bool = False,
+ ):
+ super().__init__()
+
+ # Regular linear layer weights
+ self.linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+ self.dropout = nn.Dropout(p=dropout)
+
+ # Scale for low-rank update
+ self.scale = scale
+
+ # Low rank lora weights
+ scale = 1 / math.sqrt(input_dims)
+ self.lora_a = mx.random.uniform(
+ low=-scale,
+ high=scale,
+ shape=(input_dims, r),
+ )
+ self.lora_b = mx.zeros(shape=(r, output_dims))
+
+ def __call__(self, x):
+ y = self.linear(x)
+ z = (self.dropout(x) @ self.lora_a) @ self.lora_b
+ return y + (self.scale * z).astype(x.dtype)
diff --git a/flux/flux/model.py b/flux/flux/model.py
new file mode 100644
index 00000000..18ea70b0
--- /dev/null
+++ b/flux/flux/model.py
@@ -0,0 +1,134 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .layers import (
+ DoubleStreamBlock,
+ EmbedND,
+ LastLayer,
+ MLPEmbedder,
+ SingleStreamBlock,
+ timestep_embedding,
+)
+
+
+@dataclass
+class FluxParams:
+ in_channels: int
+ vec_in_dim: int
+ context_in_dim: int
+ hidden_size: int
+ mlp_ratio: float
+ num_heads: int
+ depth: int
+ depth_single_blocks: int
+ axes_dim: list[int]
+ theta: int
+ qkv_bias: bool
+ guidance_embed: bool
+
+
+class Flux(nn.Module):
+ def __init__(self, params: FluxParams):
+ super().__init__()
+
+ self.params = params
+ self.in_channels = params.in_channels
+ self.out_channels = self.in_channels
+ if params.hidden_size % params.num_heads != 0:
+ raise ValueError(
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+ )
+ pe_dim = params.hidden_size // params.num_heads
+ if sum(params.axes_dim) != pe_dim:
+ raise ValueError(
+ f"Got {params.axes_dim} but expected positional dim {pe_dim}"
+ )
+ self.hidden_size = params.hidden_size
+ self.num_heads = params.num_heads
+ self.pe_embedder = EmbedND(
+ dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
+ )
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+ self.guidance_in = (
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ if params.guidance_embed
+ else nn.Identity()
+ )
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+ self.double_blocks = [
+ DoubleStreamBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=params.mlp_ratio,
+ qkv_bias=params.qkv_bias,
+ )
+ for _ in range(params.depth)
+ ]
+
+ self.single_blocks = [
+ SingleStreamBlock(
+ self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
+ )
+ for _ in range(params.depth_single_blocks)
+ ]
+
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+
+ def sanitize(self, weights):
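+        # Map checkpoint names onto this module tree: parameters stored as
+        # ".scale" are renamed to ".weight", and nn.Sequential submodules
+        # (img_mlp, txt_mlp, adaLN_modulation) get the extra ".layers." segment.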
+ new_weights = {}
+ for k, w in weights.items():
+ if k.endswith(".scale"):
+ k = k[:-6] + ".weight"
+ for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
+ if f".{seq}." in k:
+ k = k.replace(f".{seq}.", f".{seq}.layers.")
+ break
+ new_weights[k] = w
+ return new_weights
+
+ def __call__(
+ self,
+ img: mx.array,
+ img_ids: mx.array,
+ txt: mx.array,
+ txt_ids: mx.array,
+ timesteps: mx.array,
+ y: mx.array,
+ guidance: Optional[mx.array] = None,
+ ) -> mx.array:
+ if img.ndim != 3 or txt.ndim != 3:
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+ img = self.img_in(img)
+ vec = self.time_in(timestep_embedding(timesteps, 256))
+ if self.params.guidance_embed:
+ if guidance is None:
+ raise ValueError(
+ "Didn't get guidance strength for guidance distilled model."
+ )
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+ vec = vec + self.vector_in(y)
+ txt = self.txt_in(txt)
+
+ ids = mx.concatenate([txt_ids, img_ids], axis=1)
+ pe = self.pe_embedder(ids).astype(img.dtype)
+
+ for block in self.double_blocks:
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+
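+        # Single-stream blocks run on the concatenated [txt, img] sequence; the
+        # txt tokens are sliced off again before the final projection layer.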
+ img = mx.concatenate([txt, img], axis=1)
+ for block in self.single_blocks:
+ img = block(img, vec=vec, pe=pe)
+ img = img[:, txt.shape[1] :, ...]
+
+ img = self.final_layer(img, vec)
+
+ return img
diff --git a/flux/flux/sampler.py b/flux/flux/sampler.py
new file mode 100644
index 00000000..3bff1ca2
--- /dev/null
+++ b/flux/flux/sampler.py
@@ -0,0 +1,56 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from functools import lru_cache
+
+import mlx.core as mx
+
+
+class FluxSampler:
+ def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5):
+ self._base_shift = base_shift
+ self._max_shift = max_shift
+ self._schnell = "schnell" in name
+
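+    # Resolution-dependent schedule shift: mu is linearly interpolated between
+    # (256, base_shift) and (4096, max_shift) as a function of the image
+    # sequence length x, then applied as t -> e^mu / (e^mu + 1/t - 1).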
+ def _time_shift(self, x, t):
+ x1, x2 = 256, 4096
+ t1, t2 = self._base_shift, self._max_shift
+ exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1)
+ t = exp_mu / (exp_mu + (1 / t - 1))
+ return t
+
+ @lru_cache
+ def timesteps(
+ self, num_steps, image_sequence_length, start: float = 1, stop: float = 0
+ ):
+ t = mx.linspace(start, stop, num_steps + 1)
+
+ if self._schnell:
+ t = self._time_shift(image_sequence_length, t)
+
+ return t.tolist()
+
+ def random_timesteps(self, B, L, dtype=mx.float32, key=None):
+ if self._schnell:
+            # TODO: Should we upweight 1 and 0.75?
+ t = mx.random.randint(1, 5, shape=(B,), key=key)
+ t = t.astype(dtype) / 4
+ else:
+ t = mx.random.uniform(shape=(B,), dtype=dtype, key=key)
+ t = self._time_shift(L, t)
+
+ return t
+
+ def sample_prior(self, shape, dtype=mx.float32, key=None):
+ return mx.random.normal(shape, dtype=dtype, key=key)
+
+ def add_noise(self, x, t, noise=None, key=None):
+ noise = (
+ noise
+ if noise is not None
+ else mx.random.normal(x.shape, dtype=x.dtype, key=key)
+ )
+ return x * (1 - t) + t * noise
+
+ def step(self, pred, x_t, t, t_prev):
+ return x_t + (t_prev - t) * pred
diff --git a/flux/flux/t5.py b/flux/flux/t5.py
new file mode 100644
index 00000000..cf0515cd
--- /dev/null
+++ b/flux/flux/t5.py
@@ -0,0 +1,244 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+_SHARED_REPLACEMENT_PATTERNS = [
+ (".block.", ".layers."),
+ (".k.", ".key_proj."),
+ (".o.", ".out_proj."),
+ (".q.", ".query_proj."),
+ (".v.", ".value_proj."),
+ ("shared.", "wte."),
+ ("lm_head.", "lm_head.linear."),
+ (".layer.0.layer_norm.", ".ln1."),
+ (".layer.1.layer_norm.", ".ln2."),
+ (".layer.2.layer_norm.", ".ln3."),
+ (".final_layer_norm.", ".ln."),
+ (
+ "layers.0.layer.0.SelfAttention.relative_attention_bias.",
+ "relative_attention_bias.embeddings.",
+ ),
+]
+
+_ENCODER_REPLACEMENT_PATTERNS = [
+ (".layer.0.SelfAttention.", ".attention."),
+ (".layer.1.DenseReluDense.", ".dense."),
+]
+
+
+@dataclass
+class T5Config:
+ vocab_size: int
+ num_layers: int
+ num_heads: int
+ relative_attention_num_buckets: int
+ d_kv: int
+ d_model: int
+ feed_forward_proj: str
+ tie_word_embeddings: bool
+
+ d_ff: Optional[int] = None
+ num_decoder_layers: Optional[int] = None
+ relative_attention_max_distance: int = 128
+ layer_norm_epsilon: float = 1e-6
+
+ @classmethod
+ def from_dict(cls, config):
+ return cls(
+ vocab_size=config["vocab_size"],
+ num_layers=config["num_layers"],
+ num_heads=config["num_heads"],
+ relative_attention_num_buckets=config["relative_attention_num_buckets"],
+ d_kv=config["d_kv"],
+ d_model=config["d_model"],
+ feed_forward_proj=config["feed_forward_proj"],
+ tie_word_embeddings=config["tie_word_embeddings"],
+ d_ff=config.get("d_ff", 4 * config["d_model"]),
+ num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]),
+ relative_attention_max_distance=config.get(
+ "relative_attention_max_distance", 128
+ ),
+ layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6),
+ )
+
+
+class RelativePositionBias(nn.Module):
+    def __init__(self, config: T5Config, bidirectional: bool):
+        super().__init__()
+        self.bidirectional = bidirectional
+ self.num_buckets = config.relative_attention_num_buckets
+ self.max_distance = config.relative_attention_max_distance
+ self.n_heads = config.num_heads
+ self.embeddings = nn.Embedding(self.num_buckets, self.n_heads)
+
+ @staticmethod
+ def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance):
+ num_buckets = num_buckets // 2 if bidirectional else num_buckets
+ max_exact = num_buckets // 2
+
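+        # Small relative distances (< max_exact) each get their own bucket; larger
+        # distances share logarithmically spaced buckets up to max_distance.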
+ abspos = rpos.abs()
+ is_small = abspos < max_exact
+
+ scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
+ buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
+ buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)
+
+ buckets = mx.where(is_small, abspos, buckets_large)
+ if bidirectional:
+ buckets = buckets + (rpos > 0) * num_buckets
+ else:
+ buckets = buckets * (rpos < 0)
+
+ return buckets
+
+ def __call__(self, query_length: int, key_length: int, offset: int = 0):
+ """Compute binned relative position bias"""
+ context_position = mx.arange(offset, query_length)[:, None]
+ memory_position = mx.arange(key_length)[None, :]
+
+ # shape (query_length, key_length)
+ relative_position = memory_position - context_position
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position,
+ bidirectional=self.bidirectional,
+ num_buckets=self.num_buckets,
+ max_distance=self.max_distance,
+ )
+
+ # shape (query_length, key_length, num_heads)
+ values = self.embeddings(relative_position_bucket)
+
+ # shape (num_heads, query_length, key_length)
+ return values.transpose(2, 0, 1)
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, config: T5Config):
+ super().__init__()
+ inner_dim = config.d_kv * config.num_heads
+ self.num_heads = config.num_heads
+ self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+ self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+ self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+ self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
+
+ def __call__(
+ self,
+ queries: mx.array,
+ keys: mx.array,
+ values: mx.array,
+ mask: Optional[mx.array],
+ cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
+ queries = self.query_proj(queries)
+ keys = self.key_proj(keys)
+ values = self.value_proj(values)
+
+ num_heads = self.num_heads
+ B, L, _ = queries.shape
+ _, S, _ = keys.shape
+ queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+ keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+ values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+ if cache is not None:
+ key_cache, value_cache = cache
+            keys = mx.concatenate([key_cache, keys], axis=2)
+ values = mx.concatenate([value_cache, values], axis=2)
+
+ values_hat = mx.fast.scaled_dot_product_attention(
+ queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype)
+ )
+ values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+ return self.out_proj(values_hat), (keys, values)
+
+
+class DenseActivation(nn.Module):
+ def __init__(self, config: T5Config):
+ super().__init__()
+ mlp_dims = config.d_ff or config.d_model * 4
+ self.gated = config.feed_forward_proj.startswith("gated")
+ if self.gated:
+ self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
+ self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
+ else:
+ self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
+ self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
+ activation = config.feed_forward_proj.removeprefix("gated-")
+ if activation == "relu":
+ self.act = nn.relu
+ elif activation == "gelu":
+ self.act = nn.gelu
+ elif activation == "silu":
+ self.act = nn.silu
+ else:
+ raise ValueError(f"Unknown activation: {activation}")
+
+ def __call__(self, x):
+ if self.gated:
+ hidden_act = self.act(self.wi_0(x))
+ hidden_linear = self.wi_1(x)
+ x = hidden_act * hidden_linear
+ else:
+ x = self.act(self.wi(x))
+ return self.wo(x)
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(self, config: T5Config):
+ super().__init__()
+ self.attention = MultiHeadAttention(config)
+ self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dense = DenseActivation(config)
+
+ def __call__(self, x, mask):
+ y = self.ln1(x)
+ y, _ = self.attention(y, y, y, mask=mask)
+ x = x + y
+
+ y = self.ln2(x)
+ y = self.dense(y)
+ return x + y
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, config: T5Config):
+ super().__init__()
+ self.layers = [
+ TransformerEncoderLayer(config) for i in range(config.num_layers)
+ ]
+ self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
+
+ def __call__(self, x: mx.array):
+ pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
+ pos_bias = pos_bias.astype(x.dtype)
+ for layer in self.layers:
+ x = layer(x, mask=pos_bias)
+ return self.ln(x)
+
+
+class T5Encoder(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        self.wte = nn.Embedding(config.vocab_size, config.d_model)
+ self.encoder = TransformerEncoder(config)
+
+ def sanitize(self, weights):
+ new_weights = {}
+ for k, w in weights.items():
+ for old, new in _SHARED_REPLACEMENT_PATTERNS:
+ k = k.replace(old, new)
+ if k.startswith("encoder."):
+ for old, new in _ENCODER_REPLACEMENT_PATTERNS:
+ k = k.replace(old, new)
+ new_weights[k] = w
+ return new_weights
+
+ def __call__(self, inputs: mx.array):
+ return self.encoder(self.wte(inputs))
diff --git a/flux/flux/tokenizers.py b/flux/flux/tokenizers.py
new file mode 100644
index 00000000..796ef389
--- /dev/null
+++ b/flux/flux/tokenizers.py
@@ -0,0 +1,185 @@
+# Copyright © 2024 Apple Inc.
+
+import mlx.core as mx
+import regex
+from sentencepiece import SentencePieceProcessor
+
+
+class CLIPTokenizer:
+ """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
+
+ def __init__(self, bpe_ranks, vocab, max_length=77):
+ self.max_length = max_length
+ self.bpe_ranks = bpe_ranks
+ self.vocab = vocab
+ self.pat = regex.compile(
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+ regex.IGNORECASE,
+ )
+
+ self._cache = {self.bos: self.bos, self.eos: self.eos}
+
+ @property
+ def bos(self):
+ return "<|startoftext|>"
+
+ @property
+ def bos_token(self):
+ return self.vocab[self.bos]
+
+ @property
+ def eos(self):
+ return "<|endoftext|>"
+
+ @property
+ def eos_token(self):
+ return self.vocab[self.eos]
+
+ def bpe(self, text):
+ if text in self._cache:
+ return self._cache[text]
+
+        unigrams = list(text[:-1]) + [text[-1] + "</w>"]
+ unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+ if not unique_bigrams:
+ return unigrams
+
+ # In every iteration try to merge the two most likely bigrams. If none
+ # was merged we are done.
+ #
+ # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
+ while unique_bigrams:
+ bigram = min(
+ unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
+ )
+ if bigram not in self.bpe_ranks:
+ break
+
+ new_unigrams = []
+ skip = False
+ for a, b in zip(unigrams, unigrams[1:]):
+ if skip:
+ skip = False
+ continue
+
+ if (a, b) == bigram:
+ new_unigrams.append(a + b)
+ skip = True
+
+ else:
+ new_unigrams.append(a)
+
+ if not skip:
+ new_unigrams.append(b)
+
+ unigrams = new_unigrams
+ unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+ self._cache[text] = unigrams
+
+ return unigrams
+
+ def tokenize(self, text, prepend_bos=True, append_eos=True):
+ if isinstance(text, list):
+ return [self.tokenize(t, prepend_bos, append_eos) for t in text]
+
+ # Lower case cleanup and split according to self.pat. Hugging Face does
+ # a much more thorough job here but this should suffice for 95% of
+ # cases.
+ clean_text = regex.sub(r"\s+", " ", text.lower())
+ tokens = regex.findall(self.pat, clean_text)
+
+ # Split the tokens according to the byte-pair merge file
+ bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
+
+ # Map to token ids and return
+ tokens = [self.vocab[t] for t in bpe_tokens]
+ if prepend_bos:
+ tokens = [self.bos_token] + tokens
+ if append_eos:
+ tokens.append(self.eos_token)
+
+ if len(tokens) > self.max_length:
+ tokens = tokens[: self.max_length]
+ if append_eos:
+ tokens[-1] = self.eos_token
+
+ return tokens
+
+ def encode(self, text):
+ if not isinstance(text, list):
+ return self.encode([text])
+
+ tokens = self.tokenize(text)
+ length = max(len(t) for t in tokens)
+ for t in tokens:
+ t.extend([self.eos_token] * (length - len(t)))
+
+ return mx.array(tokens)
+
+
+class T5Tokenizer:
+ def __init__(self, model_file, max_length=512):
+ self._tokenizer = SentencePieceProcessor(model_file)
+ self.max_length = max_length
+
+ @property
+ def pad(self):
+ try:
+ return self._tokenizer.id_to_piece(self.pad_token)
+ except IndexError:
+ return None
+
+ @property
+ def pad_token(self):
+ return self._tokenizer.pad_id()
+
+ @property
+ def bos(self):
+ try:
+ return self._tokenizer.id_to_piece(self.bos_token)
+ except IndexError:
+ return None
+
+ @property
+ def bos_token(self):
+ return self._tokenizer.bos_id()
+
+ @property
+ def eos(self):
+ try:
+ return self._tokenizer.id_to_piece(self.eos_token)
+ except IndexError:
+ return None
+
+ @property
+ def eos_token(self):
+ return self._tokenizer.eos_id()
+
+ def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
+ if isinstance(text, list):
+ return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]
+
+ tokens = self._tokenizer.encode(text)
+
+ if prepend_bos and self.bos_token >= 0:
+ tokens = [self.bos_token] + tokens
+ if append_eos and self.eos_token >= 0:
+ tokens.append(self.eos_token)
+ if pad and len(tokens) < self.max_length and self.pad_token >= 0:
+ tokens += [self.pad_token] * (self.max_length - len(tokens))
+
+ return tokens
+
+ def encode(self, text, pad=True):
+ if not isinstance(text, list):
+ return self.encode([text], pad=pad)
+
+ pad_token = self.pad_token if self.pad_token >= 0 else 0
+ tokens = self.tokenize(text, pad=pad)
+ length = max(len(t) for t in tokens)
+ for t in tokens:
+ t.extend([pad_token] * (length - len(t)))
+
+ return mx.array(tokens)
diff --git a/flux/flux/utils.py b/flux/flux/utils.py
new file mode 100644
index 00000000..21db17d3
--- /dev/null
+++ b/flux/flux/utils.py
@@ -0,0 +1,209 @@
+# Copyright © 2024 Apple Inc.
+
+import json
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+import mlx.core as mx
+from huggingface_hub import hf_hub_download
+
+from .autoencoder import AutoEncoder, AutoEncoderParams
+from .clip import CLIPTextModel, CLIPTextModelConfig
+from .model import Flux, FluxParams
+from .t5 import T5Config, T5Encoder
+from .tokenizers import CLIPTokenizer, T5Tokenizer
+
+
+@dataclass
+class ModelSpec:
+ params: FluxParams
+ ae_params: AutoEncoderParams
+ ckpt_path: Optional[str]
+ ae_path: Optional[str]
+ repo_id: Optional[str]
+ repo_flow: Optional[str]
+ repo_ae: Optional[str]
+
+
+configs = {
+ "flux-dev": ModelSpec(
+ repo_id="black-forest-labs/FLUX.1-dev",
+ repo_flow="flux1-dev.safetensors",
+ repo_ae="ae.safetensors",
+ ckpt_path=os.getenv("FLUX_DEV"),
+ params=FluxParams(
+ in_channels=64,
+ vec_in_dim=768,
+ context_in_dim=4096,
+ hidden_size=3072,
+ mlp_ratio=4.0,
+ num_heads=24,
+ depth=19,
+ depth_single_blocks=38,
+ axes_dim=[16, 56, 56],
+ theta=10_000,
+ qkv_bias=True,
+ guidance_embed=True,
+ ),
+ ae_path=os.getenv("AE"),
+ ae_params=AutoEncoderParams(
+ resolution=256,
+ in_channels=3,
+ ch=128,
+ out_ch=3,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ z_channels=16,
+ scale_factor=0.3611,
+ shift_factor=0.1159,
+ ),
+ ),
+ "flux-schnell": ModelSpec(
+ repo_id="black-forest-labs/FLUX.1-schnell",
+ repo_flow="flux1-schnell.safetensors",
+ repo_ae="ae.safetensors",
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
+ params=FluxParams(
+ in_channels=64,
+ vec_in_dim=768,
+ context_in_dim=4096,
+ hidden_size=3072,
+ mlp_ratio=4.0,
+ num_heads=24,
+ depth=19,
+ depth_single_blocks=38,
+ axes_dim=[16, 56, 56],
+ theta=10_000,
+ qkv_bias=True,
+ guidance_embed=False,
+ ),
+ ae_path=os.getenv("AE"),
+ ae_params=AutoEncoderParams(
+ resolution=256,
+ in_channels=3,
+ ch=128,
+ out_ch=3,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ z_channels=16,
+ scale_factor=0.3611,
+ shift_factor=0.1159,
+ ),
+ ),
+}
+
+
+def load_flow_model(name: str, hf_download: bool = True):
+ # Get the safetensors file to load
+ ckpt_path = configs[name].ckpt_path
+
+ # Download if needed
+ if (
+ ckpt_path is None
+ and configs[name].repo_id is not None
+ and configs[name].repo_flow is not None
+ and hf_download
+ ):
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
+
+ # Make the model
+ model = Flux(configs[name].params)
+
+ # Load the checkpoint if needed
+ if ckpt_path is not None:
+ weights = mx.load(ckpt_path)
+ weights = model.sanitize(weights)
+ model.load_weights(list(weights.items()))
+
+ return model
+
+
+def load_ae(name: str, hf_download: bool = True):
+ # Get the safetensors file to load
+ ckpt_path = configs[name].ae_path
+
+ # Download if needed
+ if (
+ ckpt_path is None
+ and configs[name].repo_id is not None
+ and configs[name].repo_ae is not None
+ and hf_download
+ ):
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
+
+ # Make the autoencoder
+ ae = AutoEncoder(configs[name].ae_params)
+
+ # Load the checkpoint if needed
+ if ckpt_path is not None:
+ weights = mx.load(ckpt_path)
+ weights = ae.sanitize(weights)
+ ae.load_weights(list(weights.items()))
+
+ return ae
+
+
+def load_clip(name: str):
+ # Load the config
+ config_path = hf_hub_download(configs[name].repo_id, "text_encoder/config.json")
+ with open(config_path) as f:
+ config = CLIPTextModelConfig.from_dict(json.load(f))
+
+ # Make the clip text encoder
+ clip = CLIPTextModel(config)
+
+ # Load the weights
+ ckpt_path = hf_hub_download(configs[name].repo_id, "text_encoder/model.safetensors")
+ weights = mx.load(ckpt_path)
+ weights = clip.sanitize(weights)
+ clip.load_weights(list(weights.items()))
+
+ return clip
+
+
+def load_t5(name: str):
+ # Load the config
+ config_path = hf_hub_download(configs[name].repo_id, "text_encoder_2/config.json")
+ with open(config_path) as f:
+ config = T5Config.from_dict(json.load(f))
+
+ # Make the T5 model
+ t5 = T5Encoder(config)
+
+ # Load the weights
+ model_index = hf_hub_download(
+ configs[name].repo_id, "text_encoder_2/model.safetensors.index.json"
+ )
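+    # The index file maps each parameter name to the shard that contains it;
+    # download every referenced shard and merge the weights before sanitizing.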
+ weight_files = set()
+ with open(model_index) as f:
+ for _, w in json.load(f)["weight_map"].items():
+ weight_files.add(w)
+ weights = {}
+ for w in weight_files:
+ w = f"text_encoder_2/{w}"
+ w = hf_hub_download(configs[name].repo_id, w)
+ weights.update(mx.load(w))
+ weights = t5.sanitize(weights)
+ t5.load_weights(list(weights.items()))
+
+ return t5
+
+
+def load_clip_tokenizer(name: str):
+ vocab_file = hf_hub_download(configs[name].repo_id, "tokenizer/vocab.json")
+ with open(vocab_file, encoding="utf-8") as f:
+ vocab = json.load(f)
+
+ merges_file = hf_hub_download(configs[name].repo_id, "tokenizer/merges.txt")
+ with open(merges_file, encoding="utf-8") as f:
+ bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
+ bpe_merges = [tuple(m.split()) for m in bpe_merges]
+ bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
+
+ return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
+
+
+def load_t5_tokenizer(name: str, pad: bool = True):
+ model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
+ return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
diff --git a/flux/requirements.txt b/flux/requirements.txt
new file mode 100644
index 00000000..792205c9
--- /dev/null
+++ b/flux/requirements.txt
@@ -0,0 +1,7 @@
+mlx>=0.18.1
+huggingface-hub
+regex
+numpy
+tqdm
+Pillow
+sentencepiece
diff --git a/flux/static/dog-r4-g8-1200-512x1024.png b/flux/static/dog-r4-g8-1200-512x1024.png
new file mode 100644
index 0000000000000000000000000000000000000000..7b1ca0e62ac424e0f91fc86feadf0fb47d1bad32