diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 55a2b5db..2a49ee37 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_additive_causal_mask +from .base import BaseModelArgs, KVCache, create_additive_causal_mask @dataclass @@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs): num_attention_heads: int rms_norm_eps: float vocab_size: int - num_key_value_heads: int = None + num_key_value_heads: Optional[int] = None attention_bias: bool = False mlp_bias: bool = False rope_theta: float = 10000 @@ -73,7 +73,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -135,7 +135,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index 3282dff2..e4a8cc7d 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -1,10 +1,11 @@ from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, KVCache +from .su_rope import SuScaledRotaryEmbedding @dataclass @@ -16,10 +17,12 @@ class ModelArgs(BaseModelArgs): num_attention_heads: int rms_norm_eps: float vocab_size: int - num_key_value_heads: int = None + num_key_value_heads: Optional[int] = None rope_theta: float = 10000 rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None + rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None + max_position_embeddings: int = 131072 + original_max_position_embeddings: int = 4096 def __post_init__(self): if self.num_key_value_heads is None: @@ -30,9 +33,9 @@ class ModelArgs(BaseModelArgs): if not all(key in self.rope_scaling for key in required_keys): raise ValueError(f"rope_scaling must contain keys {required_keys}") - if self.rope_scaling["type"] != "linear": + if self.rope_scaling["type"] not in ["su", "linear"]: print( - "[WARNING] rope_scaling 'type' currently only supports 'linear' setting rope scaling to false." + "[WARNING] rope_scaling 'type' currently only supports 'linear' and 'su'; setting rope scaling to false." ) self.rope_scaling = None @@ -43,6 +46,7 @@ class Attention(nn.Module): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads + assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads self.num_hidden_layers = args.num_hidden_layers @@ -53,23 +57,34 @@ class Attention(nn.Module): self.qkv_proj = nn.Linear(dim, op_size, bias=False) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - rope_scale = ( - 1 / args.rope_scaling["factor"] - if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" - else 1 - ) - self.rope = nn.RoPE( - head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - scale=rope_scale, - ) + rope_scale = 1.0 + if args.rope_scaling and args.rope_scaling["type"] == "su": + self.rope = SuScaledRotaryEmbedding( + head_dim, + traditional=False, + base=args.rope_theta, + scale=rope_scale, + max_position_embeddings=args.max_position_embeddings, + original_max_position_embeddings=args.original_max_position_embeddings, + short_factor=args.rope_scaling["short_factor"], + long_factor=args.rope_scaling["long_factor"], + ) + else: + if args.rope_scaling and args.rope_scaling["type"] == "linear": + assert isinstance(args.rope_scaling["factor"], float) + rope_scale = 1 / args.rope_scaling["factor"] + self.rope = nn.RoPE( + head_dim, + traditional=args.rope_traditional, + base=args.rope_theta, + scale=rope_scale, + ) def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -128,7 +143,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py index f3644a56..e0f2d856 100644 --- a/llms/mlx_lm/models/phi3small.py +++ b/llms/mlx_lm/models/phi3small.py @@ -1,3 +1,4 @@ +import math from dataclasses import dataclass from functools import partial from typing import Dict, Optional, Tuple, Union @@ -5,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, KVCache @dataclass @@ -19,14 +20,14 @@ class ModelArgs(BaseModelArgs): num_attention_heads: int layer_norm_epsilon: float vocab_size: int - num_key_value_heads: int = None + num_key_value_heads: Optional[int] = None mup_attn_multiplier: float = 1.0 mup_use_scaling: bool = True mup_embedding_multiplier: float = 10.0 mup_width_multiplier: float = 8.0 rope_embedding_base: float = 1000000 rope_position_scale: float = 1.0 - blocksparse_block_size: int = (64,) + blocksparse_block_size: Tuple[int] = (64,) blocksparse_num_local_blocks: int = 16 blocksparse_vert_stride: int = 8 @@ -58,6 +59,7 @@ class Attention(nn.Module): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads + assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads self.n_q_per_kv = n_heads // n_kv_heads @@ -157,7 +159,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -226,7 +228,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index b928de09..fab09003 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, KVCache @dataclass @@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs): num_attention_heads: int rms_norm_eps: float vocab_size: int - num_key_value_heads: int = None + num_key_value_heads: Optional[int] = None rope_theta: float = 1000000 rope_traditional: bool = False rope_scaling: Optional[Dict[str, Union[float, str]]] = None @@ -41,6 +41,7 @@ class Attention(nn.Module): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads + assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads head_dim = args.hidden_size // n_heads @@ -67,7 +68,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -121,7 +122,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index bba02da0..57f154a0 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, KVCache from .switch_layers import SwitchGLU @@ -22,7 +22,7 @@ class ModelArgs(BaseModelArgs): shared_expert_intermediate_size: int rms_norm_eps: float vocab_size: int - num_key_value_heads: int = None + num_key_value_heads: Optional[int] = None rope_theta: float = 1000000 rope_traditional: bool = False rope_scaling: Optional[Dict[str, Union[float, str]]] = None @@ -47,6 +47,7 @@ class Attention(nn.Module): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads + assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads head_dim = args.hidden_size // n_heads @@ -67,7 +68,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -159,7 +160,7 @@ class Qwen2MoeDecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index ca06bdb1..7b058d8f 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -4,7 +4,7 @@ from typing import Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, KVCache @dataclass @@ -43,7 +43,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -98,7 +98,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py new file mode 100644 index 00000000..cdf6ceaf --- /dev/null +++ b/llms/mlx_lm/models/su_rope.py @@ -0,0 +1,79 @@ +import math +from typing import List, Union + +import mlx.core as mx + + +class SuScaledRotaryEmbedding: + def __init__( + self, + dims: int, + traditional: bool = False, + base: float = 10000.0, + scale: float = 1.0, + max_position_embeddings: int = 131072, + original_max_position_embeddings: int = 4096, + short_factor: Union[List[float], float] = 1.0, + long_factor: Union[List[float], float] = 1.0, + ): + """ + Phi3Su Scaled Rotary Embedding layer for Phi-3 models. + + Args: + dims (int): The feature dimensions to be rotated. + traditional (bool, optional): Unused. Default: ``False``. + base (int, optional): Base for the exponential scaling. + scale (float, optional): The scale used to scale the positions. + Default: ``1.0``. + max_position_embeddings (int, optional): The maximum sequence + length that this model was trained with. This is used to determine + the size of the original RoPE embeddings when using long scaling. + Default: ``131072``. + original_max_position_embeddings (int, optional): The maximum + sequence length that this model was trained with. This is used to + determine the size of the original RoPE embeddings when using long + scaling. Default: ``4096``. + short_factor (float or list[float], optional): List of scaling + factors for sequences of length lesser than + ``original_max_position_embeddings``. Default: ``1.0``. + long_factor (float or list[float], optional): List of scaling + factors for sequences of length greater than + ``original_max_position_embeddings``. Default: ``1.0``. + """ + self.inv_freq_short = 1.0 / ( + mx.array(short_factor, dtype=mx.float32) + * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) + ) + self.inv_freq_long = 1.0 / ( + scale + * mx.array(long_factor, dtype=mx.float32) + * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) + ) + self.original_max_position_embeddings = original_max_position_embeddings + self.scaling_factor = math.sqrt( + 1 + + math.log(max_position_embeddings / original_max_position_embeddings) + / math.log(original_max_position_embeddings) + ) + + def _get_cos_sin(self, offset, L): + position_ids = mx.arange(offset, offset + L, dtype=mx.float32) + inv_freq = ( + self.inv_freq_long + if (offset + L) > self.original_max_position_embeddings + else self.inv_freq_short + ) + freqs = position_ids[:, None] * inv_freq[None, :] + emb = mx.concatenate([freqs, freqs], axis=-1) + cos = mx.cos(emb) * self.scaling_factor + sin = mx.sin(emb) * self.scaling_factor + return cos, sin + + def __call__(self, x, offset: int = 0): + def _rotate_half(_x): + midpoint = _x.shape[-1] // 2 + x1, x2 = _x[..., :midpoint], _x[..., midpoint:] + return mx.concatenate([-x2, x1], axis=-1) + + cos, sin = self._get_cos_sin(offset, x.shape[2]) + return (x * cos) + (_rotate_half(x) * sin) diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 0523be50..97a9b40c 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -140,7 +140,8 @@ class APIHandler(BaseHTTPRequestHandler): self.validate_model_parameters() # Get stop id sequences, if provided - stop_words = self.body.get("stop", []) + stop_words = self.body.get("stop") + stop_words = stop_words or [] stop_words = [stop_words] if isinstance(stop_words, str) else stop_words stop_id_sequences = [ self.tokenizer.encode(stop_word, add_special_tokens=False) @@ -171,14 +172,14 @@ class APIHandler(BaseHTTPRequestHandler): if not isinstance(self.max_tokens, int) or self.max_tokens < 0: raise ValueError("max_tokens must be a non-negative integer") - if not isinstance(self.temperature, float) or self.temperature < 0: + if not isinstance(self.temperature, (float, int)) or self.temperature < 0: raise ValueError("temperature must be a non-negative float") - if not isinstance(self.top_p, float) or self.top_p < 0 or self.top_p > 1: + if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1: raise ValueError("top_p must be a float between 0 and 1") if ( - not isinstance(self.repetition_penalty, float) + not isinstance(self.repetition_penalty, (float, int)) or self.repetition_penalty < 0 ): raise ValueError("repetition_penalty must be a non-negative float") @@ -527,6 +528,18 @@ def main(): help="Set the MLX cache limit in GB", required=False, ) + parser.add_argument( + "--chat-template", + type=str, + default="", + help="Specify a chat template for the tokenizer", + required=False, + ) + parser.add_argument( + "--use-default-chat-template", + action="store_true", + help="Use the default chat template", + ) args = parser.parse_args() logging.basicConfig( @@ -540,10 +553,17 @@ def main(): # Building tokenizer_config tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} + if args.chat_template: + tokenizer_config["chat_template"] = args.chat_template model, tokenizer = load( args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config ) + + if args.use_default_chat_template: + if tokenizer.chat_template is None: + tokenizer.chat_template = tokenizer.default_chat_template + run(args.host, args.port, model, tokenizer) diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 4a4fb8f0..5a0e2b5f 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -151,6 +151,12 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) for i in indices: # Encode batch batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]] + for b in batch: + if b[-1] == tokenizer.eos_token_id: + print("[WARNING] Example already has an EOS token appended") + else: + b.append(tokenizer.eos_token_id) + lengths = [len(x) for x in batch] if max(lengths) > max_seq_length: diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py index 086e3505..88c3e75e 100644 --- a/llms/mlx_lm/version.py +++ b/llms/mlx_lm/version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.14.2" +__version__ = "0.15.0" diff --git a/whisper/mlx_whisper/audio.py b/whisper/mlx_whisper/audio.py index 81fa41e3..e04309c1 100644 --- a/whisper/mlx_whisper/audio.py +++ b/whisper/mlx_whisper/audio.py @@ -151,8 +151,6 @@ def log_mel_spectrogram( mx.array, shape = (80, n_frames) An array that contains the Mel spectrogram """ - device = mx.default_device() - mx.set_default_device(mx.cpu) if isinstance(audio, str): audio = load_audio(audio) elif not isinstance(audio, mx.array): @@ -170,5 +168,4 @@ def log_mel_spectrogram( log_spec = mx.maximum(mel_spec, 1e-10).log10() log_spec = mx.maximum(log_spec, log_spec.max() - 8.0) log_spec = (log_spec + 4.0) / 4.0 - mx.set_default_device(device) return log_spec