Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-31 20:04:38 +08:00.

Commit: Merge branch 'ml-explore:main' into completion_only
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, create_additive_causal_mask
+from .base import BaseModelArgs, KVCache, create_additive_causal_mask
 
 
 @dataclass
@@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     attention_bias: bool = False
     mlp_bias: bool = False
     rope_theta: float = 10000
@@ -73,7 +73,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -135,7 +135,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
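The change that recurs throughout the model files in this commit replaces the per-layer `(keys, values)` tuple cache with the `KVCache` object now exported from `.base`. As a minimal sketch of what such an object looks like, assuming an `update_and_fetch`-style API (the class name comes from the diff; this body is illustrative, not the repository's exact implementation):

import mlx.core as mx


class KVCache:
    # Illustrative sketch: accumulates keys/values along the sequence
    # axis and tracks the offset used for positional encodings.
    def __init__(self):
        self.keys = None
        self.values = None
        self.offset = 0

    def update_and_fetch(self, keys, values):
        # keys/values: (batch, n_kv_heads, seq_len, head_dim)
        if self.keys is None:
            self.keys, self.values = keys, values
        else:
            self.keys = mx.concatenate([self.keys, keys], axis=2)
            self.values = mx.concatenate([self.values, values], axis=2)
        self.offset = self.keys.shape[2]
        return self.keys, self.values

Passing one mutable object per layer lets the attention module update the cache in place instead of threading a new `(keys, values)` tuple back out on every decoding step.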
@@ -1,10 +1,11 @@
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
+from .su_rope import SuScaledRotaryEmbedding
 
 
 @dataclass
@@ -16,10 +17,12 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
+    max_position_embeddings: int = 131072
+    original_max_position_embeddings: int = 4096
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
@@ -30,9 +33,9 @@ class ModelArgs(BaseModelArgs):
         if not all(key in self.rope_scaling for key in required_keys):
             raise ValueError(f"rope_scaling must contain keys {required_keys}")
 
-        if self.rope_scaling["type"] != "linear":
+        if self.rope_scaling["type"] not in ["su", "linear"]:
             print(
-                "[WARNING] rope_scaling 'type' currently only supports 'linear' setting rope scaling to false."
+                "[WARNING] rope_scaling 'type' currently only supports 'linear' and 'su'; setting rope scaling to false."
             )
             self.rope_scaling = None
 
@@ -43,6 +46,7 @@ class Attention(nn.Module):
 
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         self.num_hidden_layers = args.num_hidden_layers
 
@@ -53,23 +57,34 @@ class Attention(nn.Module):
         self.qkv_proj = nn.Linear(dim, op_size, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
-        rope_scale = (
-            1 / args.rope_scaling["factor"]
-            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
-            else 1
-        )
-        self.rope = nn.RoPE(
-            head_dim,
-            traditional=args.rope_traditional,
-            base=args.rope_theta,
-            scale=rope_scale,
-        )
+        rope_scale = 1.0
+        if args.rope_scaling and args.rope_scaling["type"] == "su":
+            self.rope = SuScaledRotaryEmbedding(
+                head_dim,
+                traditional=False,
+                base=args.rope_theta,
+                scale=rope_scale,
+                max_position_embeddings=args.max_position_embeddings,
+                original_max_position_embeddings=args.original_max_position_embeddings,
+                short_factor=args.rope_scaling["short_factor"],
+                long_factor=args.rope_scaling["long_factor"],
+            )
+        else:
+            if args.rope_scaling and args.rope_scaling["type"] == "linear":
+                assert isinstance(args.rope_scaling["factor"], float)
+                rope_scale = 1 / args.rope_scaling["factor"]
+            self.rope = nn.RoPE(
+                head_dim,
+                traditional=args.rope_traditional,
+                base=args.rope_theta,
+                scale=rope_scale,
+            )
 
     def __call__(
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -128,7 +143,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
    ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
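With this change the Phi-3 attention module dispatches on `rope_scaling["type"]`: `"su"` selects the new `SuScaledRotaryEmbedding`, while `"linear"` (or no scaling) falls back to `nn.RoPE`. As an illustration, a config fragment of roughly the shape below would take the `su` branch; the factor values here are made-up placeholders, not real model values:

rope_scaling = {
    "type": "su",
    # One entry per rotated dimension pair (head_dim // 2); values are hypothetical.
    "short_factor": [1.0, 1.0, 1.05, 1.1],
    "long_factor": [1.0, 1.2, 1.5, 2.0],
}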
@@ -1,3 +1,4 @@
+import math
 from dataclasses import dataclass
 from functools import partial
 from typing import Dict, Optional, Tuple, Union
@@ -5,7 +6,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 
 
 @dataclass
@@ -19,14 +20,14 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     layer_norm_epsilon: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     mup_attn_multiplier: float = 1.0
     mup_use_scaling: bool = True
     mup_embedding_multiplier: float = 10.0
     mup_width_multiplier: float = 8.0
     rope_embedding_base: float = 1000000
     rope_position_scale: float = 1.0
-    blocksparse_block_size: int = (64,)
+    blocksparse_block_size: Tuple[int] = (64,)
     blocksparse_num_local_blocks: int = 16
     blocksparse_vert_stride: int = 8
 
@@ -58,6 +59,7 @@ class Attention(nn.Module):
 
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         self.n_q_per_kv = n_heads // n_kv_heads
 
@@ -157,7 +159,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -226,7 +228,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
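Two small typing fixes recur across these model files: `num_key_value_heads` is declared `Optional[int]` because its default really is `None` (the diffed `__post_init__` fills it in when absent), and `blocksparse_block_size` gets the annotation `Tuple[int]` to match its one-element tuple default `(64,)`, which the previous `int` annotation contradicted. The added `assert args.num_key_value_heads is not None` then narrows the `Optional[int]` back to `int` at the point of use. A reduced sketch of the pattern (field names from the diff; the fallback-to-`num_attention_heads` behavior is assumed from the `__post_init__` context shown above):

from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    num_attention_heads: int
    num_key_value_heads: Optional[int] = None

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads


args = Args(num_attention_heads=32)
assert args.num_key_value_heads is not None  # narrows Optional[int] to int
n_kv_heads: int = args.num_key_value_heads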
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 
 
 @dataclass
@@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -41,6 +41,7 @@ class Attention(nn.Module):
 
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
         head_dim = args.hidden_size // n_heads
@@ -67,7 +68,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -121,7 +122,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
@@ -5,7 +5,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 from .switch_layers import SwitchGLU
 
 
@@ -22,7 +22,7 @@ class ModelArgs(BaseModelArgs):
     shared_expert_intermediate_size: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -47,6 +47,7 @@ class Attention(nn.Module):
 
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
         head_dim = args.hidden_size // n_heads
@@ -67,7 +68,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -159,7 +160,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
@@ -4,7 +4,7 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs
+from .base import BaseModelArgs, KVCache
 
 
 @dataclass
@@ -43,7 +43,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         B, L, D = x.shape
 
@@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[KVCache] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
llms/mlx_lm/models/su_rope.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+import math
+from typing import List, Union
+
+import mlx.core as mx
+
+
+class SuScaledRotaryEmbedding:
+    def __init__(
+        self,
+        dims: int,
+        traditional: bool = False,
+        base: float = 10000.0,
+        scale: float = 1.0,
+        max_position_embeddings: int = 131072,
+        original_max_position_embeddings: int = 4096,
+        short_factor: Union[List[float], float] = 1.0,
+        long_factor: Union[List[float], float] = 1.0,
+    ):
+        """
+        Phi3Su Scaled Rotary Embedding layer for Phi-3 models.
+
+        Args:
+            dims (int): The feature dimensions to be rotated.
+            traditional (bool, optional): Unused. Default: ``False``.
+            base (int, optional): Base for the exponential scaling.
+            scale (float, optional): The scale used to scale the positions.
+                Default: ``1.0``.
+            max_position_embeddings (int, optional): The maximum sequence
+                length that this model was trained with. This is used to
+                determine the size of the original RoPE embeddings when using
+                long scaling. Default: ``131072``.
+            original_max_position_embeddings (int, optional): The maximum
+                sequence length that this model was trained with. This is used
+                to determine the size of the original RoPE embeddings when
+                using long scaling. Default: ``4096``.
+            short_factor (float or list[float], optional): List of scaling
+                factors for sequences of length lesser than
+                ``original_max_position_embeddings``. Default: ``1.0``.
+            long_factor (float or list[float], optional): List of scaling
+                factors for sequences of length greater than
+                ``original_max_position_embeddings``. Default: ``1.0``.
+        """
+        self.inv_freq_short = 1.0 / (
+            mx.array(short_factor, dtype=mx.float32)
+            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+        )
+        self.inv_freq_long = 1.0 / (
+            scale
+            * mx.array(long_factor, dtype=mx.float32)
+            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+        )
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.scaling_factor = math.sqrt(
+            1
+            + math.log(max_position_embeddings / original_max_position_embeddings)
+            / math.log(original_max_position_embeddings)
+        )
+
+    def _get_cos_sin(self, offset, L):
+        position_ids = mx.arange(offset, offset + L, dtype=mx.float32)
+        inv_freq = (
+            self.inv_freq_long
+            if (offset + L) > self.original_max_position_embeddings
+            else self.inv_freq_short
+        )
+        freqs = position_ids[:, None] * inv_freq[None, :]
+        emb = mx.concatenate([freqs, freqs], axis=-1)
+        cos = mx.cos(emb) * self.scaling_factor
+        sin = mx.sin(emb) * self.scaling_factor
+        return cos, sin
+
+    def __call__(self, x, offset: int = 0):
+        def _rotate_half(_x):
+            midpoint = _x.shape[-1] // 2
+            x1, x2 = _x[..., :midpoint], _x[..., midpoint:]
+            return mx.concatenate([-x2, x1], axis=-1)
+
+        cos, sin = self._get_cos_sin(offset, x.shape[2])
+        return (x * cos) + (_rotate_half(x) * sin)
@@ -140,7 +140,8 @@ class APIHandler(BaseHTTPRequestHandler):
         self.validate_model_parameters()
 
         # Get stop id sequences, if provided
-        stop_words = self.body.get("stop", [])
+        stop_words = self.body.get("stop")
+        stop_words = stop_words or []
         stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
         stop_id_sequences = [
             self.tokenizer.encode(stop_word, add_special_tokens=False)
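Splitting the lookup from the default matters because a client may send an explicit "stop": null in the request body; dict.get only applies its default when the key is missing, so the old code could still end up with None. A quick illustration of the difference:

body = {"stop": None}     # explicit null from the client
body.get("stop", [])      # -> None: the default only covers missing keys
body.get("stop") or []    # -> []: also normalizes an explicit None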
@@ -171,14 +172,14 @@ class APIHandler(BaseHTTPRequestHandler):
         if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
             raise ValueError("max_tokens must be a non-negative integer")
 
-        if not isinstance(self.temperature, float) or self.temperature < 0:
+        if not isinstance(self.temperature, (float, int)) or self.temperature < 0:
             raise ValueError("temperature must be a non-negative float")
 
-        if not isinstance(self.top_p, float) or self.top_p < 0 or self.top_p > 1:
+        if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1:
             raise ValueError("top_p must be a float between 0 and 1")
 
         if (
-            not isinstance(self.repetition_penalty, float)
+            not isinstance(self.repetition_penalty, (float, int))
             or self.repetition_penalty < 0
         ):
             raise ValueError("repetition_penalty must be a non-negative float")
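Accepting (float, int) matters because JSON has a single number type: a client sending "temperature": 1 yields a Python int after parsing, which the previous float-only check rejected even though the value is valid. For example:

import json

v = json.loads('{"temperature": 1}')["temperature"]
isinstance(v, float)         # False: a JSON integer parses to int
isinstance(v, (float, int))  # True: the relaxed check accepts 1 and 1.0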
@@ -527,6 +528,18 @@ def main():
         help="Set the MLX cache limit in GB",
         required=False,
     )
+    parser.add_argument(
+        "--chat-template",
+        type=str,
+        default="",
+        help="Specify a chat template for the tokenizer",
+        required=False,
+    )
+    parser.add_argument(
+        "--use-default-chat-template",
+        action="store_true",
+        help="Use the default chat template",
+    )
     args = parser.parse_args()
 
     logging.basicConfig(
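Together these two flags let a deployment control the chat template at startup. As a hedged usage sketch (assuming the server is launched as a module, e.g. `python -m mlx_lm.server --model <path-or-hf-repo>`): pass `--chat-template '<jinja template string>'` to inject a template into the tokenizer config, or pass `--use-default-chat-template` to fall back to the tokenizer's `default_chat_template` when none is set, as wired up in the next hunk.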
@@ -540,10 +553,17 @@ def main():
 
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
+    if args.chat_template:
+        tokenizer_config["chat_template"] = args.chat_template
+
     model, tokenizer = load(
         args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
 
+    if args.use_default_chat_template:
+        if tokenizer.chat_template is None:
+            tokenizer.chat_template = tokenizer.default_chat_template
+
     run(args.host, args.port, model, tokenizer)
 
@@ -151,6 +151,12 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
         for i in indices:
             # Encode batch
             batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
+            for b in batch:
+                if b[-1] == tokenizer.eos_token_id:
+                    print("[WARNING] Example already has an EOS token appended")
+                else:
+                    b.append(tokenizer.eos_token_id)
+
             lengths = [len(x) for x in batch]
 
             if max(lengths) > max_seq_length:
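The training batcher now guarantees each encoded example ends with EOS exactly once: examples that already carry the token trigger a warning, the rest get it appended in place. A toy illustration of the effect (token ids are hypothetical, with an EOS id of 2):

batch = [[5, 9], [7, 4, 2]]  # second example already ends in EOS
eos_token_id = 2
for b in batch:
    if b[-1] == eos_token_id:
        print("[WARNING] Example already has an EOS token appended")
    else:
        b.append(eos_token_id)
# batch is now [[5, 9, 2], [7, 4, 2]]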
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.14.2"
+__version__ = "0.15.0"
@@ -151,8 +151,6 @@ def log_mel_spectrogram(
         mx.array, shape = (80, n_frames)
         An array that contains the Mel spectrogram
     """
-    device = mx.default_device()
-    mx.set_default_device(mx.cpu)
     if isinstance(audio, str):
         audio = load_audio(audio)
     elif not isinstance(audio, mx.array):
@@ -170,5 +168,4 @@ def log_mel_spectrogram(
     log_spec = mx.maximum(mel_spec, 1e-10).log10()
     log_spec = mx.maximum(log_spec, log_spec.max() - 8.0)
     log_spec = (log_spec + 4.0) / 4.0
-    mx.set_default_device(device)
     return log_spec
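With the save/restore of the default device removed, `log_mel_spectrogram` no longer forces its work through `mx.cpu`; the spectrogram is now computed on whatever device is currently the default, presumably because the operations it relies on are supported on the GPU backend by this release.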