Merge branch 'ml-explore:main' into completion_only

This commit is contained in:
Chime Ogbuji 2024-06-15 21:01:05 -04:00 committed by GitHub
commit 8c1d33d523
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 168 additions and 47 deletions

View File

@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_additive_causal_mask
from .base import BaseModelArgs, KVCache, create_additive_causal_mask
@dataclass
@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
num_key_value_heads: Optional[int] = None
attention_bias: bool = False
mlp_bias: bool = False
rope_theta: float = 10000
@ -73,7 +73,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, D = x.shape
@ -135,7 +135,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r

View File

@ -1,10 +1,11 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, KVCache
from .su_rope import SuScaledRotaryEmbedding
@dataclass
@ -16,10 +17,12 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
num_key_value_heads: Optional[int] = None
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
max_position_embeddings: int = 131072
original_max_position_embeddings: int = 4096
def __post_init__(self):
if self.num_key_value_heads is None:
@ -30,9 +33,9 @@ class ModelArgs(BaseModelArgs):
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
if self.rope_scaling["type"] not in ["su", "linear"]:
print(
"[WARNING] rope_scaling 'type' currently only supports 'linear' setting rope scaling to false."
"[WARNING] rope_scaling 'type' currently only supports 'linear' and 'su'; setting rope scaling to false."
)
self.rope_scaling = None
@ -43,6 +46,7 @@ class Attention(nn.Module):
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.num_hidden_layers = args.num_hidden_layers
@ -53,23 +57,34 @@ class Attention(nn.Module):
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] == "su":
self.rope = SuScaledRotaryEmbedding(
head_dim,
traditional=False,
base=args.rope_theta,
scale=rope_scale,
max_position_embeddings=args.max_position_embeddings,
original_max_position_embeddings=args.original_max_position_embeddings,
short_factor=args.rope_scaling["short_factor"],
long_factor=args.rope_scaling["long_factor"],
)
else:
if args.rope_scaling and args.rope_scaling["type"] == "linear":
assert isinstance(args.rope_scaling["factor"], float)
rope_scale = 1 / args.rope_scaling["factor"]
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, D = x.shape
@ -128,7 +143,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r

View File

@ -1,3 +1,4 @@
import math
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Tuple, Union
@ -5,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, KVCache
@dataclass
@ -19,14 +20,14 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
layer_norm_epsilon: float
vocab_size: int
num_key_value_heads: int = None
num_key_value_heads: Optional[int] = None
mup_attn_multiplier: float = 1.0
mup_use_scaling: bool = True
mup_embedding_multiplier: float = 10.0
mup_width_multiplier: float = 8.0
rope_embedding_base: float = 1000000
rope_position_scale: float = 1.0
blocksparse_block_size: int = (64,)
blocksparse_block_size: Tuple[int] = (64,)
blocksparse_num_local_blocks: int = 16
blocksparse_vert_stride: int = 8
@ -58,6 +59,7 @@ class Attention(nn.Module):
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.n_q_per_kv = n_heads // n_kv_heads
@ -157,7 +159,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, D = x.shape
@ -226,7 +228,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r

View File

@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, KVCache
@dataclass
@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
num_key_value_heads: Optional[int] = None
rope_theta: float = 1000000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@ -41,6 +41,7 @@ class Attention(nn.Module):
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
@ -67,7 +68,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, D = x.shape
@ -121,7 +122,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r

View File

@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, KVCache
from .switch_layers import SwitchGLU
@ -22,7 +22,7 @@ class ModelArgs(BaseModelArgs):
shared_expert_intermediate_size: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
num_key_value_heads: Optional[int] = None
rope_theta: float = 1000000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@ -47,6 +47,7 @@ class Attention(nn.Module):
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
@ -67,7 +68,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, D = x.shape
@ -159,7 +160,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r

View File

@ -4,7 +4,7 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, KVCache
@dataclass
@ -43,7 +43,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, D = x.shape
@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r

View File

@ -0,0 +1,79 @@
import math
from typing import List, Union
import mlx.core as mx
class SuScaledRotaryEmbedding:
def __init__(
self,
dims: int,
traditional: bool = False,
base: float = 10000.0,
scale: float = 1.0,
max_position_embeddings: int = 131072,
original_max_position_embeddings: int = 4096,
short_factor: Union[List[float], float] = 1.0,
long_factor: Union[List[float], float] = 1.0,
):
"""
Phi3Su Scaled Rotary Embedding layer for Phi-3 models.
Args:
dims (int): The feature dimensions to be rotated.
traditional (bool, optional): Unused. Default: ``False``.
base (int, optional): Base for the exponential scaling.
scale (float, optional): The scale used to scale the positions.
Default: ``1.0``.
max_position_embeddings (int, optional): The maximum sequence
length that this model was trained with. This is used to determine
the size of the original RoPE embeddings when using long scaling.
Default: ``131072``.
original_max_position_embeddings (int, optional): The maximum
sequence length that this model was trained with. This is used to
determine the size of the original RoPE embeddings when using long
scaling. Default: ``4096``.
short_factor (float or list[float], optional): List of scaling
factors for sequences of length lesser than
``original_max_position_embeddings``. Default: ``1.0``.
long_factor (float or list[float], optional): List of scaling
factors for sequences of length greater than
``original_max_position_embeddings``. Default: ``1.0``.
"""
self.inv_freq_short = 1.0 / (
mx.array(short_factor, dtype=mx.float32)
* base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
)
self.inv_freq_long = 1.0 / (
scale
* mx.array(long_factor, dtype=mx.float32)
* base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
)
self.original_max_position_embeddings = original_max_position_embeddings
self.scaling_factor = math.sqrt(
1
+ math.log(max_position_embeddings / original_max_position_embeddings)
/ math.log(original_max_position_embeddings)
)
def _get_cos_sin(self, offset, L):
position_ids = mx.arange(offset, offset + L, dtype=mx.float32)
inv_freq = (
self.inv_freq_long
if (offset + L) > self.original_max_position_embeddings
else self.inv_freq_short
)
freqs = position_ids[:, None] * inv_freq[None, :]
emb = mx.concatenate([freqs, freqs], axis=-1)
cos = mx.cos(emb) * self.scaling_factor
sin = mx.sin(emb) * self.scaling_factor
return cos, sin
def __call__(self, x, offset: int = 0):
def _rotate_half(_x):
midpoint = _x.shape[-1] // 2
x1, x2 = _x[..., :midpoint], _x[..., midpoint:]
return mx.concatenate([-x2, x1], axis=-1)
cos, sin = self._get_cos_sin(offset, x.shape[2])
return (x * cos) + (_rotate_half(x) * sin)

View File

@ -140,7 +140,8 @@ class APIHandler(BaseHTTPRequestHandler):
self.validate_model_parameters()
# Get stop id sequences, if provided
stop_words = self.body.get("stop", [])
stop_words = self.body.get("stop")
stop_words = stop_words or []
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
stop_id_sequences = [
self.tokenizer.encode(stop_word, add_special_tokens=False)
@ -171,14 +172,14 @@ class APIHandler(BaseHTTPRequestHandler):
if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
raise ValueError("max_tokens must be a non-negative integer")
if not isinstance(self.temperature, float) or self.temperature < 0:
if not isinstance(self.temperature, (float, int)) or self.temperature < 0:
raise ValueError("temperature must be a non-negative float")
if not isinstance(self.top_p, float) or self.top_p < 0 or self.top_p > 1:
if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1:
raise ValueError("top_p must be a float between 0 and 1")
if (
not isinstance(self.repetition_penalty, float)
not isinstance(self.repetition_penalty, (float, int))
or self.repetition_penalty < 0
):
raise ValueError("repetition_penalty must be a non-negative float")
@ -527,6 +528,18 @@ def main():
help="Set the MLX cache limit in GB",
required=False,
)
parser.add_argument(
"--chat-template",
type=str,
default="",
help="Specify a chat template for the tokenizer",
required=False,
)
parser.add_argument(
"--use-default-chat-template",
action="store_true",
help="Use the default chat template",
)
args = parser.parse_args()
logging.basicConfig(
@ -540,10 +553,17 @@ def main():
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
if args.chat_template:
tokenizer_config["chat_template"] = args.chat_template
model, tokenizer = load(
args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
)
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
run(args.host, args.port, model, tokenizer)

View File

@ -151,6 +151,12 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
for i in indices:
# Encode batch
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
for b in batch:
if b[-1] == tokenizer.eos_token_id:
print("[WARNING] Example already has an EOS token appended")
else:
b.append(tokenizer.eos_token_id)
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:

View File

@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.14.2"
__version__ = "0.15.0"

View File

@ -151,8 +151,6 @@ def log_mel_spectrogram(
mx.array, shape = (80, n_frames)
An array that contains the Mel spectrogram
"""
device = mx.default_device()
mx.set_default_device(mx.cpu)
if isinstance(audio, str):
audio = load_audio(audio)
elif not isinstance(audio, mx.array):
@ -170,5 +168,4 @@ def log_mel_spectrogram(
log_spec = mx.maximum(mel_spec, 1e-10).log10()
log_spec = mx.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
mx.set_default_device(device)
return log_spec