Merge branch 'main' into adding-full-finetuning

Author: Gökdeniz Gülmez
Date: 2024-08-21 22:00:30 +02:00 (committed via GitHub)
34 changed files with 523 additions and 14 deletions


@@ -76,7 +76,12 @@ def setup_arg_parser():
type=int,
default=None,
help="Set the MLX cache limit in GB",
required=False,
)
parser.add_argument(
"--max-kv-size",
type=int,
default=1024,
help="Set the maximum key-value cache size (in tokens)",
)
return parser
@@ -154,6 +159,7 @@ def main():
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
max_kv_size=args.max_kv_size,
)
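
For reference, the new option is also usable from the Python API, since main() simply forwards args.max_kv_size to generate. A minimal sketch (the model path and prompt are illustrative):

from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")  # illustrative path
text = generate(
    model,
    tokenizer,
    prompt="Summarize the plot of Hamlet.",
    max_tokens=256,
    max_kv_size=1024,  # bound the KV cache; the oldest entries get overwritten
)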


@@ -1,6 +1,8 @@
# Copyright © 2023-2024 Apple Inc.
import inspect
from dataclasses import dataclass
from typing import List, Optional
from typing import Any, List, Optional
import mlx.core as mx
import mlx.nn as nn
@@ -44,6 +46,100 @@ class KVCache:
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
def state(self):
return self.keys, self.values
class RotatingKVCache:
def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
self.keep = keep
self.keys = None
self.values = None
self.offset = 0
self.max_size = max_size
self.step = step
self._idx = 0
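    # _trim drops `trim_size` entries from the sequence axis, starting right
    # after the first `keep` entries (e.g. keep=4, trim_size=3 removes
    # positions 4-6), and optionally appends new keys/values at the end.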
def _trim(self, trim_size, v, append=None):
to_cat = []
if trim_size > 0:
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
else:
to_cat = [v]
if append is not None:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def update_and_fetch(self, keys, values):
prev = self.offset
B, _, S = keys.shape[:3]
# Prefill mode
if S > 1:
if self.keys is None:
self.keys = keys
self.values = values
else:
# The largest size is self.max_size + S - 1 to ensure
# every token gets at least self.max_size context
trim_size = self.keys.shape[2] - self.max_size + 1
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += S
self._idx = self.keys.shape[2]
return self.keys, self.values
# Generation mode
# May not have hit the max size yet, so potentially
# keep growing the cache
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
):
new_size = min(self.step, self.max_size - prev)
k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self._idx = prev
# Trim if needed
trim_size = self.keys.shape[2] - self.max_size
if trim_size > 0:
self.keys = self._trim(trim_size, self.keys)
self.values = self._trim(trim_size, self.values)
self._idx = self.max_size
# Rotate
if self._idx == self.max_size:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + 1, :] = keys
self.values[..., self._idx : self._idx + 1, :] = values
self.offset += 1
self._idx += 1
        # If the buffer is not yet full, return only the filled prefix
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
def state(self):
return self.keys, self.values
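# Illustrative usage (not part of the module): the rotating cache acts as a
# fixed-size window over the sequence axis. Shapes mirror the unit test added
# in this commit.
#
#   cache = RotatingKVCache(head_dim=32, n_kv_heads=2, max_size=8, keep=4)
#   k = mx.random.uniform(shape=(1, 2, 20, 32))
#   v = mx.random.uniform(shape=(1, 2, 20, 32))
#   k_all, v_all = cache.update_and_fetch(k, v)   # prefill: all 20 entries returned
#   k1 = mx.random.uniform(shape=(1, 2, 1, 32))
#   v1 = mx.random.uniform(shape=(1, 2, 1, 32))
#   k_win, v_win = cache.update_and_fetch(k1, v1) # trimmed to max_size=8 entries,
#                                                 # positions 0-3 (keep=4) stay pinned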
@dataclass
class BaseModelArgs:
@@ -65,13 +161,17 @@ def create_additive_causal_mask(N: int, offset: int = 0):
return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[List[KVCache]] = None):
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
T = h.shape[1]
if T > 1:
        # Input consists of multiple tokens, so create a causal mask so that earlier
        # tokens do not attend to later tokens. If a cache is in place
        # (e.g. because of prompt reuse), offset the mask accordingly.
offset = cache[0].offset if cache is not None and cache[0] is not None else 0
if cache is not None and cache[0] is not None:
c = cache[0]
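            # A rotating cache holds a bounded number of prior positions, so cap
            # the mask offset at max_size - 1 instead of the raw token count.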
if isinstance(c, RotatingKVCache):
offset = min(c.max_size - 1, c.offset)
else:
offset = c.offset
else:
offset = 0
mask = create_additive_causal_mask(T, offset)
mask = mask.astype(h.dtype)
else:


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple


@@ -0,0 +1,266 @@
from dataclasses import dataclass
from typing import Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "deepseek"
vocab_size: int = 102400
hidden_size: int = 4096
intermediate_size: int = 11008
moe_intermediate_size: int = 1407
num_hidden_layers: int = 30
num_attention_heads: int = 32
num_key_value_heads: int = 32
n_shared_experts: Optional[int] = None
n_routed_experts: Optional[int] = None
num_experts_per_tok: Optional[int] = None
moe_layer_freq: int = 1
first_k_dense_replace: int = 0
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Optional[Dict] = None
attention_bias: bool = False
class DeepseekAttention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // config.num_attention_heads
self.scale = self.head_dim**-0.5
attention_bias = getattr(config, "attention_bias", False)
self.q_proj = nn.Linear(
self.hidden_size,
config.num_attention_heads * self.head_dim,
bias=attention_bias,
)
self.k_proj = nn.Linear(
self.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=attention_bias,
)
self.v_proj = nn.Linear(
self.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=attention_bias,
)
self.o_proj = nn.Linear(
self.hidden_size,
config.num_attention_heads * self.head_dim,
bias=attention_bias,
)
rope_scale = 1.0
if config.rope_scaling and config.rope_scaling["type"] == "linear":
assert isinstance(config.rope_scaling["factor"], float)
rope_scale = 1 / config.rope_scaling["factor"]
self.rope = nn.RoPE(
self.head_dim,
base=config.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, _ = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
0, 2, 1, 3
)
keys = keys.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class DeepseekMLP(nn.Module):
def __init__(
self,
config: ModelArgs,
hidden_size: int | None = None,
intermediate_size: int | None = None,
):
super().__init__()
self.config = config
self.hidden_size = hidden_size or config.hidden_size
self.intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = nn.silu
def __call__(self, x: mx.array) -> mx.array:
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
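    # Routing: softmax over all expert logits, then keep the top
    # `num_experts_per_tok` experts per token (indices plus their softmax
    # scores, which are not re-normalized after the top-k selection).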
def __call__(self, x):
gates = x @ self.weight.T
scores = mx.softmax(gates, axis=-1, precise=True)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(scores, inds, axis=-1)
return inds, scores
class DeepseekMoE(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.switch_mlp = SwitchGLU(
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekMLP(
config=config, intermediate_size=intermediate_size
)
def __call__(self, x):
inds, scores = self.gate(x)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(x)
return y
class DeepseekDecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int):
super().__init__()
self.self_attn = DeepseekAttention(config)
self.mlp = (
DeepseekMoE(config)
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
)
else DeepseekMLP(config)
)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class DeepseekModel(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
DeepseekDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
]
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
x: mx.array,
cache: Optional[KVCache] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.model_type = config.model_type
self.model = DeepseekModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache: Optional[KVCache] = None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
def sanitize(self, weights):
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for m in ["gate_proj", "down_proj", "up_proj"]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
for e in range(self.args.n_routed_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads
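
A quick smoke test of the new model, in the spirit of model_test_runner in the test suite. The tiny configuration values are illustrative, and the import path assumes the file lands at mlx_lm/models/deepseek.py (as model_type = "deepseek" suggests):

import mlx.core as mx
from mlx_lm.models.deepseek import Model, ModelArgs

args = ModelArgs(
    model_type="deepseek",
    vocab_size=100,
    hidden_size=32,
    intermediate_size=64,
    moe_intermediate_size=16,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    n_routed_experts=4,
    n_shared_experts=1,
    num_experts_per_tok=2,
)
model = Model(args)
logits = model(mx.array([[0, 1, 2, 3]]))  # -> shape (1, 4, vocab_size)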


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from sys import exit
from typing import Optional, Tuple


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Tuple


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from functools import partial


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import inspect
import math
from dataclasses import dataclass


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Tuple


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import List, Literal, Optional
@@ -53,6 +55,9 @@ class RecurrentCache:
def update(self, conv_state, recurrent_state):
self._cache = (conv_state, recurrent_state)
def state(self):
return self._cache
class WindowKVCache:
@@ -80,6 +85,9 @@ class WindowKVCache:
self.values = _update(self.values, values)
return self.keys, self.values
def state(self):
return self.keys, self.values
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Tuple


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math
from typing import List, Union


@@ -1,3 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
import math
import mlx.core as mx


@@ -103,7 +103,8 @@ def linear_to_lora_layers(
"gemma2",
"starcoder2",
"cohere",
"minicpm"
"minicpm",
"deepseek",
]:
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
if model.model_type == "mixtral":

View File

@@ -19,7 +19,7 @@ from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer
# Local imports
from .models.base import KVCache
from .models.base import KVCache, RotatingKVCache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import apply_lora_layers
@@ -136,6 +136,8 @@ def generate_step(
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
logit_bias: Optional[Dict[int, float]] = None,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -156,6 +158,9 @@ def generate_step(
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
logit_bias (dictionary, optional): Additive logit bias.
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@@ -197,7 +202,13 @@ def generate_step(
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
cache = [KVCache(model.head_dim, n) for n in kv_heads]
if max_kv_size is not None:
cache = [
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
for n in kv_heads
]
else:
cache = [KVCache(model.head_dim, n) for n in kv_heads]
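    # Both cache types expose the same update_and_fetch / offset / state
    # interface, so the rest of generate_step (and the model code) treats
    # them identically.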
repetition_context = prompt.tolist()
@@ -223,6 +234,11 @@ def generate_step(
repetition_context = repetition_context[-repetition_context_size:]
return y, logprobs.squeeze(0)
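    # Process the prompt in chunks of prefill_step_size, evaluating the cache
    # state after each chunk; whatever remains (at most prefill_step_size
    # tokens) is consumed by the first _step call below.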
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=cache)
mx.eval([c.state for c in cache])
y = y[prefill_step_size:]
y, logprobs = _step(y)
mx.async_eval(y)
@@ -343,8 +359,10 @@ def generate(
return
prompt_tps = prompt_tokens.size / prompt_time
gen_tps = (token_count - 1) / gen_time
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 2**30
print(f"Peak memory: {peak_mem:.3f} GB")
return detokenizer.text


@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.16.0"
__version__ = "0.17.0"


@@ -4,7 +4,7 @@ import unittest
import mlx.core as mx
from mlx.utils import tree_map
from mlx_lm.models.base import KVCache
from mlx_lm.models.base import KVCache, RotatingKVCache
class TestModels(unittest.TestCase):
@@ -29,6 +29,64 @@ class TestModels(unittest.TestCase):
self.assertTrue(mx.array_equal(v_up, expected))
self.assertEqual(cache.offset, cache.step + 1)
def test_rotating_kv_cache(self):
b, h, d = 1, 2, 32
cache = RotatingKVCache(d, h, max_size=8, step=4)
k = mx.random.uniform(shape=(b, h, 2, d))
v = mx.random.uniform(shape=(b, h, 2, d))
k_up, v_up = cache.update_and_fetch(k, v)
self.assertTrue(mx.array_equal(k_up, k))
self.assertTrue(mx.array_equal(v_up, v))
self.assertEqual(cache.offset, 2)
k = mx.random.uniform(shape=(b, h, 5, d))
v = mx.random.uniform(shape=(b, h, 5, d))
k_up, v_up = cache.update_and_fetch(k, v)
self.assertTrue(mx.array_equal(k_up[..., 2:, :], k))
self.assertTrue(mx.array_equal(v_up[..., 2:, :], v))
k = mx.random.uniform(shape=(b, h, 4, d))
v = mx.random.uniform(shape=(b, h, 4, d))
k_up, v_up = cache.update_and_fetch(k, v)
self.assertTrue(mx.array_equal(k_up[..., -4:, :], k))
self.assertTrue(mx.array_equal(v_up[..., -4:, :], v))
idx = 0
for _ in range(10):
k = mx.random.uniform(shape=(b, h, 1, d))
v = mx.random.uniform(shape=(b, h, 1, d))
k_up, v_up = cache.update_and_fetch(k, v)
self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k))
self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v))
idx += 1
idx %= 8
# Try with nonzero keep
cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2)
# Check a large update
k = mx.random.uniform(shape=(b, h, 20, d))
v = mx.random.uniform(shape=(b, h, 20, d))
k_up, v_up = cache.update_and_fetch(k, v)
self.assertTrue(mx.array_equal(k_up, k))
self.assertTrue(mx.array_equal(v_up, v))
# A bunch of small updates
self.assertEqual(cache.offset, 20)
idx = 2
for i in range(10):
k = mx.random.uniform(shape=(b, h, 1, d))
v = mx.random.uniform(shape=(b, h, 1, d))
k_up, v_up = cache.update_and_fetch(k, v)
self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k))
self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v))
self.assertEqual(cache.offset, 21 + i)
idx += 1
if idx >= 8:
idx = 2
def model_test_runner(self, model, model_type, vocab_size, num_layers):
self.assertEqual(len(model.layers), num_layers)