Merge branch 'ml-explore:main' into completion_only

Chime Ogbuji
2024-06-15 21:01:05 -04:00
committed by GitHub
11 changed files with 168 additions and 47 deletions

View File

@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, KVCache, create_additive_causal_mask
 @dataclass
@@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     attention_bias: bool = False
     mlp_bias: bool = False
     rope_theta: float = 10000
@@ -73,7 +73,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
@@ -135,7 +135,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r

View File

@@ -1,10 +1,11 @@
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
+from .su_rope import SuScaledRotaryEmbedding
 @dataclass
@@ -16,10 +17,12 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
+    max_position_embeddings: int = 131072
+    original_max_position_embeddings: int = 4096
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
@@ -30,9 +33,9 @@ class ModelArgs(BaseModelArgs):
             if not all(key in self.rope_scaling for key in required_keys):
                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
 
-            if self.rope_scaling["type"] != "linear":
+            if self.rope_scaling["type"] not in ["su", "linear"]:
                 print(
-                    "[WARNING] rope_scaling 'type' currently only supports 'linear' setting rope scaling to false."
+                    "[WARNING] rope_scaling 'type' currently only supports 'linear' and 'su'; setting rope scaling to false."
                 )
                 self.rope_scaling = None
@@ -43,6 +46,7 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         self.num_hidden_layers = args.num_hidden_layers
@@ -53,23 +57,34 @@ class Attention(nn.Module):
         self.qkv_proj = nn.Linear(dim, op_size, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
-        rope_scale = (
-            1 / args.rope_scaling["factor"]
-            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
-            else 1
-        )
-        self.rope = nn.RoPE(
-            head_dim,
-            traditional=args.rope_traditional,
-            base=args.rope_theta,
-            scale=rope_scale,
-        )
+        rope_scale = 1.0
+        if args.rope_scaling and args.rope_scaling["type"] == "su":
+            self.rope = SuScaledRotaryEmbedding(
+                head_dim,
+                traditional=False,
+                base=args.rope_theta,
+                scale=rope_scale,
+                max_position_embeddings=args.max_position_embeddings,
+                original_max_position_embeddings=args.original_max_position_embeddings,
+                short_factor=args.rope_scaling["short_factor"],
+                long_factor=args.rope_scaling["long_factor"],
+            )
+        else:
+            if args.rope_scaling and args.rope_scaling["type"] == "linear":
+                assert isinstance(args.rope_scaling["factor"], float)
+                rope_scale = 1 / args.rope_scaling["factor"]
+            self.rope = nn.RoPE(
+                head_dim,
+                traditional=args.rope_traditional,
+                base=args.rope_theta,
+                scale=rope_scale,
+            )
 
     def __call__(
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
@@ -128,7 +143,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
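
For reference, a minimal sketch (not part of the diff) of the rope_scaling dictionary that the new "su" branch expects; the values below are placeholders, real Phi-3 checkpoints ship theirs in config.json:

rope_scaling = {
    "type": "su",
    # one factor per rotary dimension pair, i.e. head_dim // 2 entries
    "short_factor": [1.0] * 48,
    "long_factor": [1.5] * 48,
}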

View File

@@ -1,3 +1,4 @@
+import math
 from dataclasses import dataclass
 from functools import partial
 from typing import Dict, Optional, Tuple, Union
@@ -5,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 @dataclass
@@ -19,14 +20,14 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     layer_norm_epsilon: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     mup_attn_multiplier: float = 1.0
     mup_use_scaling: bool = True
     mup_embedding_multiplier: float = 10.0
     mup_width_multiplier: float = 8.0
     rope_embedding_base: float = 1000000
     rope_position_scale: float = 1.0
-    blocksparse_block_size: int = (64,)
+    blocksparse_block_size: Tuple[int] = (64,)
     blocksparse_num_local_blocks: int = 16
     blocksparse_vert_stride: int = 8
@@ -58,6 +59,7 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         self.n_q_per_kv = n_heads // n_kv_heads
@@ -157,7 +159,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
@@ -226,7 +228,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r

View File

@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 @dataclass
@@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -41,6 +41,7 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         head_dim = args.hidden_size // n_heads
@@ -67,7 +68,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
@@ -121,7 +122,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r

View File

@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 from .switch_layers import SwitchGLU
 @dataclass
@@ -22,7 +22,7 @@ class ModelArgs(BaseModelArgs):
     shared_expert_intermediate_size: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -47,6 +47,7 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         head_dim = args.hidden_size // n_heads
@@ -67,7 +68,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
@@ -159,7 +160,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r

View File

@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 @dataclass
@@ -43,7 +43,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
@@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r

View File

@@ -0,0 +1,79 @@
+import math
+from typing import List, Union
+
+import mlx.core as mx
+
+
+class SuScaledRotaryEmbedding:
+    def __init__(
+        self,
+        dims: int,
+        traditional: bool = False,
+        base: float = 10000.0,
+        scale: float = 1.0,
+        max_position_embeddings: int = 131072,
+        original_max_position_embeddings: int = 4096,
+        short_factor: Union[List[float], float] = 1.0,
+        long_factor: Union[List[float], float] = 1.0,
+    ):
+        """
+        Phi3Su Scaled Rotary Embedding layer for Phi-3 models.
+
+        Args:
+            dims (int): The feature dimensions to be rotated.
+            traditional (bool, optional): Unused. Default: ``False``.
+            base (int, optional): Base for the exponential scaling.
+            scale (float, optional): The scale used to scale the positions.
+                Default: ``1.0``.
+            max_position_embeddings (int, optional): The maximum sequence
+                length that this model was trained with. This is used to determine
+                the size of the original RoPE embeddings when using long scaling.
+                Default: ``131072``.
+            original_max_position_embeddings (int, optional): The maximum
+                sequence length that this model was trained with. This is used to
+                determine the size of the original RoPE embeddings when using long
+                scaling. Default: ``4096``.
+            short_factor (float or list[float], optional): List of scaling
+                factors for sequences of length lesser than
+                ``original_max_position_embeddings``. Default: ``1.0``.
+            long_factor (float or list[float], optional): List of scaling
+                factors for sequences of length greater than
+                ``original_max_position_embeddings``. Default: ``1.0``.
+        """
+        self.inv_freq_short = 1.0 / (
+            mx.array(short_factor, dtype=mx.float32)
+            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+        )
+        self.inv_freq_long = 1.0 / (
+            scale
+            * mx.array(long_factor, dtype=mx.float32)
+            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+        )
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.scaling_factor = math.sqrt(
+            1
+            + math.log(max_position_embeddings / original_max_position_embeddings)
+            / math.log(original_max_position_embeddings)
+        )
+
+    def _get_cos_sin(self, offset, L):
+        position_ids = mx.arange(offset, offset + L, dtype=mx.float32)
+        inv_freq = (
+            self.inv_freq_long
+            if (offset + L) > self.original_max_position_embeddings
+            else self.inv_freq_short
+        )
+        freqs = position_ids[:, None] * inv_freq[None, :]
+        emb = mx.concatenate([freqs, freqs], axis=-1)
+        cos = mx.cos(emb) * self.scaling_factor
+        sin = mx.sin(emb) * self.scaling_factor
+        return cos, sin
+
+    def __call__(self, x, offset: int = 0):
+        def _rotate_half(_x):
+            midpoint = _x.shape[-1] // 2
+            x1, x2 = _x[..., :midpoint], _x[..., midpoint:]
+            return mx.concatenate([-x2, x1], axis=-1)
+
+        cos, sin = self._get_cos_sin(offset, x.shape[2])
+        return (x * cos) + (_rotate_half(x) * sin)
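
A minimal usage sketch of the new layer (not part of the diff); the head dimension and scaling factors below are placeholders, and the import path assumes the file is installed as mlx_lm/models/su_rope.py:

import mlx.core as mx
from mlx_lm.models.su_rope import SuScaledRotaryEmbedding

head_dim = 96
rope = SuScaledRotaryEmbedding(
    head_dim,
    base=10000.0,
    max_position_embeddings=131072,
    original_max_position_embeddings=4096,
    short_factor=[1.0] * (head_dim // 2),  # placeholder factors
    long_factor=[1.5] * (head_dim // 2),
)

# Queries/keys shaped (batch, n_heads, seq_len, head_dim), as in Attention.__call__;
# offset is the number of positions already in the KV cache.
queries = mx.random.normal((1, 32, 16, head_dim))
rotated = rope(queries, offset=0)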

View File

@@ -140,7 +140,8 @@ class APIHandler(BaseHTTPRequestHandler):
         self.validate_model_parameters()
 
         # Get stop id sequences, if provided
-        stop_words = self.body.get("stop", [])
+        stop_words = self.body.get("stop")
+        stop_words = stop_words or []
         stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
         stop_id_sequences = [
             self.tokenizer.encode(stop_word, add_special_tokens=False)
@@ -171,14 +172,14 @@ class APIHandler(BaseHTTPRequestHandler):
         if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
             raise ValueError("max_tokens must be a non-negative integer")
 
-        if not isinstance(self.temperature, float) or self.temperature < 0:
+        if not isinstance(self.temperature, (float, int)) or self.temperature < 0:
             raise ValueError("temperature must be a non-negative float")
 
-        if not isinstance(self.top_p, float) or self.top_p < 0 or self.top_p > 1:
+        if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1:
             raise ValueError("top_p must be a float between 0 and 1")
 
         if (
-            not isinstance(self.repetition_penalty, float)
+            not isinstance(self.repetition_penalty, (float, int))
             or self.repetition_penalty < 0
         ):
             raise ValueError("repetition_penalty must be a non-negative float")
@@ -527,6 +528,18 @@ def main():
         help="Set the MLX cache limit in GB",
         required=False,
     )
+    parser.add_argument(
+        "--chat-template",
+        type=str,
+        default="",
+        help="Specify a chat template for the tokenizer",
+        required=False,
+    )
+    parser.add_argument(
+        "--use-default-chat-template",
+        action="store_true",
+        help="Use the default chat template",
+    )
     args = parser.parse_args()
 
     logging.basicConfig(
@@ -540,10 +553,17 @@ def main():
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
+    if args.chat_template:
+        tokenizer_config["chat_template"] = args.chat_template
+
     model, tokenizer = load(
         args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
+
+    if args.use_default_chat_template:
+        if tokenizer.chat_template is None:
+            tokenizer.chat_template = tokenizer.default_chat_template
+
     run(args.host, args.port, model, tokenizer)
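
In practice the server would then be started with something like "python -m mlx_lm.server --model <path-or-hf-repo> --use-default-chat-template" (the module path is an assumption, not shown in this diff). A Python sketch of what the two new flags do at load time, with placeholder names:

from mlx_lm import load

custom_template = ""  # contents of --chat-template, if given
tokenizer_config = {}
if custom_template:
    tokenizer_config["chat_template"] = custom_template

model, tokenizer = load("<path-or-hf-repo>", tokenizer_config=tokenizer_config)

# --use-default-chat-template: fall back to the tokenizer's built-in template
if tokenizer.chat_template is None:
    tokenizer.chat_template = tokenizer.default_chat_template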

View File

@@ -151,6 +151,12 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
         for i in indices:
             # Encode batch
             batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
+            for b in batch:
+                if b[-1] == tokenizer.eos_token_id:
+                    print("[WARNING] Example already has an EOS token appended")
+                else:
+                    b.append(tokenizer.eos_token_id)
+
             lengths = [len(x) for x in batch]
 
             if max(lengths) > max_seq_length:
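
In isolation, the added guard does the following for a single encoded example (a hedged illustration; tokenizer is any Hugging Face-style tokenizer with an eos_token_id):

ids = tokenizer.encode("some training example")
# Most tokenizers do not append EOS on encode, so add it for training;
# skip it when the dataset already terminates examples with EOS.
if ids[-1] == tokenizer.eos_token_id:
    print("[WARNING] Example already has an EOS token appended")
else:
    ids.append(tokenizer.eos_token_id)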

View File

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.14.2"
+__version__ = "0.15.0"

View File

@@ -151,8 +151,6 @@ def log_mel_spectrogram(
     mx.array, shape = (80, n_frames)
         An array that contains the Mel spectrogram
     """
-    device = mx.default_device()
-    mx.set_default_device(mx.cpu)
     if isinstance(audio, str):
         audio = load_audio(audio)
     elif not isinstance(audio, mx.array):
@@ -170,5 +168,4 @@ def log_mel_spectrogram(
     log_spec = mx.maximum(mel_spec, 1e-10).log10()
     log_spec = mx.maximum(log_spec, log_spec.max() - 8.0)
     log_spec = (log_spec + 4.0) / 4.0
-    mx.set_default_device(device)
     return log_spec