From c5da302fc4f9513023fae47f6f9b25e50388d320 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 7 Jun 2024 08:59:44 -0700 Subject: [PATCH 1/8] gpu featurization (#824) --- whisper/mlx_whisper/audio.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/whisper/mlx_whisper/audio.py b/whisper/mlx_whisper/audio.py index 81fa41e3..e04309c1 100644 --- a/whisper/mlx_whisper/audio.py +++ b/whisper/mlx_whisper/audio.py @@ -151,8 +151,6 @@ def log_mel_spectrogram( mx.array, shape = (80, n_frames) An array that contains the Mel spectrogram """ - device = mx.default_device() - mx.set_default_device(mx.cpu) if isinstance(audio, str): audio = load_audio(audio) elif not isinstance(audio, mx.array): @@ -170,5 +168,4 @@ def log_mel_spectrogram( log_spec = mx.maximum(mel_spec, 1e-10).log10() log_spec = mx.maximum(log_spec, log_spec.max() - 8.0) log_spec = (log_spec + 4.0) / 4.0 - mx.set_default_device(device) return log_spec From bb8227f18197202e3196ff7fbb1f83e329df7768 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Mon, 10 Jun 2024 14:47:31 -0700 Subject: [PATCH 2/8] Correct type annotation of llama.ModelArgs.num_key_value_heads (#827) --- llms/mlx_lm/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 55a2b5db..e7f4f16a 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs): num_attention_heads: int rms_norm_eps: float vocab_size: int - num_key_value_heads: int = None + num_key_value_heads: Optional[int] = None attention_bias: bool = False mlp_bias: bool = False rope_theta: float = 10000 From a54dfd698e70dc3ef2b79b80b05968321abc1a05 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Mon, 10 Jun 2024 15:18:34 -0700 Subject: [PATCH 3/8] Correct the type annotation of cache in llama.py (#828) * Update * Fix isort --- llms/mlx_lm/models/llama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index e7f4f16a..2a49ee37 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_additive_causal_mask +from .base import BaseModelArgs, KVCache, create_additive_causal_mask @dataclass @@ -73,7 +73,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -135,7 +135,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r From fda41545a6d85f951a6967a1002e8bef1e9f436b Mon Sep 17 00:00:00 2001 From: JosefAlbers <146810011+JosefAlbers@users.noreply.github.com> Date: Tue, 11 Jun 2024 22:20:04 +0900 Subject: [PATCH 4/8] Su-RoPE(Rotary Position Embedding) for Phi-3 (#813) * Su-RoPE * nits * Update su_rope.py * Update su_rope.py Per GPT4: "The error TypeError: 'type' object is not subscriptable is caused by using the type hint list[float] in a version of Python that does not support it. This syntax is only available in Python 3.9 and later." * Ran isort --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/models/phi3.py | 39 +++++++++++------ llms/mlx_lm/models/su_rope.py | 79 +++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 13 deletions(-) create mode 100644 llms/mlx_lm/models/su_rope.py diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index 3282dff2..b30456fd 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -5,6 +5,7 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs +from .su_rope import SuScaledRotaryEmbedding @dataclass @@ -20,6 +21,8 @@ class ModelArgs(BaseModelArgs): rope_theta: float = 10000 rope_traditional: bool = False rope_scaling: Optional[Dict[str, Union[float, str]]] = None + max_position_embeddings: int = 131072 + original_max_position_embeddings: int = 4096 def __post_init__(self): if self.num_key_value_heads is None: @@ -30,9 +33,9 @@ class ModelArgs(BaseModelArgs): if not all(key in self.rope_scaling for key in required_keys): raise ValueError(f"rope_scaling must contain keys {required_keys}") - if self.rope_scaling["type"] != "linear": + if self.rope_scaling["type"] not in ["su", "linear"]: print( - "[WARNING] rope_scaling 'type' currently only supports 'linear' setting rope scaling to false." + "[WARNING] rope_scaling 'type' currently only supports 'linear' and 'su'; setting rope scaling to false." ) self.rope_scaling = None @@ -53,17 +56,27 @@ class Attention(nn.Module): self.qkv_proj = nn.Linear(dim, op_size, bias=False) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - rope_scale = ( - 1 / args.rope_scaling["factor"] - if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" - else 1 - ) - self.rope = nn.RoPE( - head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - scale=rope_scale, - ) + rope_scale = 1.0 + if args.rope_scaling and args.rope_scaling["type"] == "su": + self.rope = SuScaledRotaryEmbedding( + head_dim, + traditional=False, + base=args.rope_theta, + scale=rope_scale, + max_position_embeddings=args.max_position_embeddings, + original_max_position_embeddings=args.original_max_position_embeddings, + short_factor=args.rope_scaling["short_factor"], + long_factor=args.rope_scaling["long_factor"], + ) + else: + if args.rope_scaling and args.rope_scaling["type"] == "linear": + rope_scale = 1 / args.rope_scaling["factor"] + self.rope = nn.RoPE( + head_dim, + traditional=args.rope_traditional, + base=args.rope_theta, + scale=rope_scale, + ) def __call__( self, diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py new file mode 100644 index 00000000..cdf6ceaf --- /dev/null +++ b/llms/mlx_lm/models/su_rope.py @@ -0,0 +1,79 @@ +import math +from typing import List, Union + +import mlx.core as mx + + +class SuScaledRotaryEmbedding: + def __init__( + self, + dims: int, + traditional: bool = False, + base: float = 10000.0, + scale: float = 1.0, + max_position_embeddings: int = 131072, + original_max_position_embeddings: int = 4096, + short_factor: Union[List[float], float] = 1.0, + long_factor: Union[List[float], float] = 1.0, + ): + """ + Phi3Su Scaled Rotary Embedding layer for Phi-3 models. + + Args: + dims (int): The feature dimensions to be rotated. + traditional (bool, optional): Unused. Default: ``False``. + base (int, optional): Base for the exponential scaling. + scale (float, optional): The scale used to scale the positions. + Default: ``1.0``. + max_position_embeddings (int, optional): The maximum sequence + length that this model was trained with. This is used to determine + the size of the original RoPE embeddings when using long scaling. + Default: ``131072``. + original_max_position_embeddings (int, optional): The maximum + sequence length that this model was trained with. This is used to + determine the size of the original RoPE embeddings when using long + scaling. Default: ``4096``. + short_factor (float or list[float], optional): List of scaling + factors for sequences of length lesser than + ``original_max_position_embeddings``. Default: ``1.0``. + long_factor (float or list[float], optional): List of scaling + factors for sequences of length greater than + ``original_max_position_embeddings``. Default: ``1.0``. + """ + self.inv_freq_short = 1.0 / ( + mx.array(short_factor, dtype=mx.float32) + * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) + ) + self.inv_freq_long = 1.0 / ( + scale + * mx.array(long_factor, dtype=mx.float32) + * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) + ) + self.original_max_position_embeddings = original_max_position_embeddings + self.scaling_factor = math.sqrt( + 1 + + math.log(max_position_embeddings / original_max_position_embeddings) + / math.log(original_max_position_embeddings) + ) + + def _get_cos_sin(self, offset, L): + position_ids = mx.arange(offset, offset + L, dtype=mx.float32) + inv_freq = ( + self.inv_freq_long + if (offset + L) > self.original_max_position_embeddings + else self.inv_freq_short + ) + freqs = position_ids[:, None] * inv_freq[None, :] + emb = mx.concatenate([freqs, freqs], axis=-1) + cos = mx.cos(emb) * self.scaling_factor + sin = mx.sin(emb) * self.scaling_factor + return cos, sin + + def __call__(self, x, offset: int = 0): + def _rotate_half(_x): + midpoint = _x.shape[-1] // 2 + x1, x2 = _x[..., :midpoint], _x[..., midpoint:] + return mx.concatenate([-x2, x1], axis=-1) + + cos, sin = self._get_cos_sin(offset, x.shape[2]) + return (x * cos) + (_rotate_half(x) * sin) From 6da07fb1b03acba3bedc9d70b56e9d091e2af651 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Wed, 12 Jun 2024 06:53:55 -0700 Subject: [PATCH 5/8] make models/phi3.py and models/phi3small.py compatible with mypy (#833) --- llms/mlx_lm/models/phi3.py | 14 ++++++++------ llms/mlx_lm/models/phi3small.py | 12 +++++++----- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index b30456fd..e4a8cc7d 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, KVCache from .su_rope import SuScaledRotaryEmbedding @@ -17,10 +17,10 @@ class ModelArgs(BaseModelArgs): num_attention_heads: int rms_norm_eps: float vocab_size: int - num_key_value_heads: int = None + num_key_value_heads: Optional[int] = None rope_theta: float = 10000 rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None + rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None max_position_embeddings: int = 131072 original_max_position_embeddings: int = 4096 @@ -46,6 +46,7 @@ class Attention(nn.Module): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads + assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads self.num_hidden_layers = args.num_hidden_layers @@ -70,6 +71,7 @@ class Attention(nn.Module): ) else: if args.rope_scaling and args.rope_scaling["type"] == "linear": + assert isinstance(args.rope_scaling["factor"], float) rope_scale = 1 / args.rope_scaling["factor"] self.rope = nn.RoPE( head_dim, @@ -82,7 +84,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -141,7 +143,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py index f3644a56..e0f2d856 100644 --- a/llms/mlx_lm/models/phi3small.py +++ b/llms/mlx_lm/models/phi3small.py @@ -1,3 +1,4 @@ +import math from dataclasses import dataclass from functools import partial from typing import Dict, Optional, Tuple, Union @@ -5,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, KVCache @dataclass @@ -19,14 +20,14 @@ class ModelArgs(BaseModelArgs): num_attention_heads: int layer_norm_epsilon: float vocab_size: int - num_key_value_heads: int = None + num_key_value_heads: Optional[int] = None mup_attn_multiplier: float = 1.0 mup_use_scaling: bool = True mup_embedding_multiplier: float = 10.0 mup_width_multiplier: float = 8.0 rope_embedding_base: float = 1000000 rope_position_scale: float = 1.0 - blocksparse_block_size: int = (64,) + blocksparse_block_size: Tuple[int] = (64,) blocksparse_num_local_blocks: int = 16 blocksparse_vert_stride: int = 8 @@ -58,6 +59,7 @@ class Attention(nn.Module): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads + assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads self.n_q_per_kv = n_heads // n_kv_heads @@ -157,7 +159,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -226,7 +228,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r From 3cc58e17fbf21d3997061c3b63994275e4dcd454 Mon Sep 17 00:00:00 2001 From: Nada Amin Date: Wed, 12 Jun 2024 10:17:06 -0400 Subject: [PATCH 6/8] Tweaks to run dspy-produced calls to the server, with gemma template. (#810) * Tweaks to run dspy-produced calls to the server, with gemma template. following comment https://github.com/stanfordnlp/dspy/issues/385#issuecomment-1998939936 can try it out with: ```sh python -m server --model mlx-community/gemma-1.1-7b-it-4bit --port 1143 ``` modulo patching the relative imports in server.py ``` -from .tokenizer_utils import TokenizerWrapper -from .utils import generate_step, load +from mlx_lm.tokenizer_utils import TokenizerWrapper +from mlx_lm.utils import generate_step, load ``` and then, ont the dspy side: ```python import dspy lm = dspy.OpenAI(model_type="chat", api_base="http://localhost:11434/v1/", api_key="not_needed", max_tokens=250) lm("hello") ``` * simpler way to validate float or int * remove logic that works around incompatible templates, too gemma specific * tweak messages for common denominator * use generate.py workaround for DBXR * put behind flag * oops * Solution to chat template issue: pass in a custom template! The template should likely adhere to the OpenAI chat model. Here is such a template for Gemma. --chat-template "{{ bos_token }}{% set extra_system = '' %}{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{% if role == 'system' %}{% set extra_system = extra_system + message['content'] %}{% else %}{% if role == 'user' and extra_system %}{% set message_system = 'System: ' + extra_system %}{% else %}{% set message_system = '' %}{% endif %}{{ '' + role + '\n' + message_system + message['content'] | trim + '\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}" * remove convoluted solution * Tweak for when None is provided explicitly, and must be set to [] too. For example, the outlines library provides None explicitly. * style --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/server.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 0523be50..97a9b40c 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -140,7 +140,8 @@ class APIHandler(BaseHTTPRequestHandler): self.validate_model_parameters() # Get stop id sequences, if provided - stop_words = self.body.get("stop", []) + stop_words = self.body.get("stop") + stop_words = stop_words or [] stop_words = [stop_words] if isinstance(stop_words, str) else stop_words stop_id_sequences = [ self.tokenizer.encode(stop_word, add_special_tokens=False) @@ -171,14 +172,14 @@ class APIHandler(BaseHTTPRequestHandler): if not isinstance(self.max_tokens, int) or self.max_tokens < 0: raise ValueError("max_tokens must be a non-negative integer") - if not isinstance(self.temperature, float) or self.temperature < 0: + if not isinstance(self.temperature, (float, int)) or self.temperature < 0: raise ValueError("temperature must be a non-negative float") - if not isinstance(self.top_p, float) or self.top_p < 0 or self.top_p > 1: + if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1: raise ValueError("top_p must be a float between 0 and 1") if ( - not isinstance(self.repetition_penalty, float) + not isinstance(self.repetition_penalty, (float, int)) or self.repetition_penalty < 0 ): raise ValueError("repetition_penalty must be a non-negative float") @@ -527,6 +528,18 @@ def main(): help="Set the MLX cache limit in GB", required=False, ) + parser.add_argument( + "--chat-template", + type=str, + default="", + help="Specify a chat template for the tokenizer", + required=False, + ) + parser.add_argument( + "--use-default-chat-template", + action="store_true", + help="Use the default chat template", + ) args = parser.parse_args() logging.basicConfig( @@ -540,10 +553,17 @@ def main(): # Building tokenizer_config tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} + if args.chat_template: + tokenizer_config["chat_template"] = args.chat_template model, tokenizer = load( args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config ) + + if args.use_default_chat_template: + if tokenizer.chat_template is None: + tokenizer.chat_template = tokenizer.default_chat_template + run(args.host, args.port, model, tokenizer) From d8b073e3a71a89b80d58f02f48cb17711642b2d1 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 12 Jun 2024 07:44:21 -0700 Subject: [PATCH 7/8] Add eos token to lora fine-tunes (#818) * add eos token to lora fine-tunes * Comment --- llms/mlx_lm/tuner/trainer.py | 6 ++++++ llms/mlx_lm/version.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index feecf523..24fcc5c6 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -92,6 +92,12 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) for i in indices: # Encode batch batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]] + for b in batch: + if b[-1] == tokenizer.eos_token_id: + print("[WARNING] Example already has an EOS token appended") + else: + b.append(tokenizer.eos_token_id) + lengths = [len(x) for x in batch] if max(lengths) > max_seq_length: diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py index 086e3505..88c3e75e 100644 --- a/llms/mlx_lm/version.py +++ b/llms/mlx_lm/version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.14.2" +__version__ = "0.15.0" From a7598e9456c6455a07ff4905712c2ea3cfcd52db Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Fri, 14 Jun 2024 09:44:50 -0700 Subject: [PATCH 8/8] Fix mypy errors with models/{qwen2,qwen2_moe,startcoder2}.py (#835) * Fix starcoder.py * Fix qwen2 * Remvoe unnecessary assert not None --- llms/mlx_lm/models/qwen2.py | 9 +++++---- llms/mlx_lm/models/qwen2_moe.py | 9 +++++---- llms/mlx_lm/models/starcoder2.py | 6 +++--- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index b928de09..fab09003 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, KVCache @dataclass @@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs): num_attention_heads: int rms_norm_eps: float vocab_size: int - num_key_value_heads: int = None + num_key_value_heads: Optional[int] = None rope_theta: float = 1000000 rope_traditional: bool = False rope_scaling: Optional[Dict[str, Union[float, str]]] = None @@ -41,6 +41,7 @@ class Attention(nn.Module): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads + assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads head_dim = args.hidden_size // n_heads @@ -67,7 +68,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -121,7 +122,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index bba02da0..57f154a0 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, KVCache from .switch_layers import SwitchGLU @@ -22,7 +22,7 @@ class ModelArgs(BaseModelArgs): shared_expert_intermediate_size: int rms_norm_eps: float vocab_size: int - num_key_value_heads: int = None + num_key_value_heads: Optional[int] = None rope_theta: float = 1000000 rope_traditional: bool = False rope_scaling: Optional[Dict[str, Union[float, str]]] = None @@ -47,6 +47,7 @@ class Attention(nn.Module): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads + assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads head_dim = args.hidden_size // n_heads @@ -67,7 +68,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -159,7 +160,7 @@ class Qwen2MoeDecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index ca06bdb1..7b058d8f 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -4,7 +4,7 @@ from typing import Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, KVCache @dataclass @@ -43,7 +43,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -98,7 +98,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r