mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
Switch to fast RMS/LN Norm (#603)
* use nn.RMSNorm, use sdpa, cleanup * bump mlx versions * minor update * use fast layer norm * version bump * update requirement for whisper * update requirement for gguf
This commit is contained in:
parent
fbed720d6f
commit
b8a348c1b8
@ -45,20 +45,6 @@ class TextConfig:
|
||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config: TextConfig):
|
||||
super().__init__()
|
||||
@ -105,10 +91,6 @@ class Attention(nn.Module):
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if self.repeats > 1:
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
@ -119,11 +101,10 @@ class Attention(nn.Module):
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output), (keys, values)
|
||||
|
||||
|
||||
@ -145,8 +126,8 @@ class TransformerBlock(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = Attention(config)
|
||||
self.mlp = MLP(config.hidden_size, config.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
self.config = config
|
||||
@ -175,7 +156,7 @@ class Llama(nn.Module):
|
||||
self.layers = [
|
||||
TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -1,4 +1,4 @@
|
||||
mlx>=0.5.0
|
||||
mlx>=0.8.0
|
||||
numpy
|
||||
transformers
|
||||
torch
|
||||
|
@ -49,20 +49,6 @@ class ModelArgs:
|
||||
)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
@ -107,10 +93,6 @@ class Attention(nn.Module):
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if self.repeats > 1:
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
@ -121,11 +103,10 @@ class Attention(nn.Module):
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output), (keys, values)
|
||||
|
||||
|
||||
@ -147,8 +128,10 @@ class TransformerBlock(nn.Module):
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
@ -175,7 +158,7 @@ class LlamaModel(nn.Module):
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -1,4 +1,4 @@
|
||||
mlx>=0.0.11
|
||||
mlx>=0.8
|
||||
numpy
|
||||
protobuf==3.20.2
|
||||
sentencepiece
|
||||
|
@ -28,20 +28,6 @@ class ModelArgs:
|
||||
rope_traditional: bool = True
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
@ -120,8 +106,8 @@ class TransformerBlock(nn.Module):
|
||||
self.dim = args.dim
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = FeedForward(args=args)
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
@ -144,7 +130,7 @@ class Llama(nn.Module):
|
||||
self.vocab_size = args.vocab_size
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
|
@ -1,4 +1,4 @@
|
||||
mlx>=0.0.6
|
||||
mlx>=0.8.0
|
||||
sentencepiece
|
||||
torch
|
||||
numpy
|
||||
|
@ -26,20 +26,6 @@ class ModelArgs:
|
||||
rope_theta: float = 10000
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
@ -73,9 +59,6 @@ class Attention(nn.Module):
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
@ -86,11 +69,10 @@ class Attention(nn.Module):
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.wo(output), (keys, values)
|
||||
|
||||
|
||||
@ -113,8 +95,8 @@ class TransformerBlock(nn.Module):
|
||||
self.dim = args.dim
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = FeedForward(args=args)
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
@ -139,7 +121,7 @@ class Mistral(nn.Module):
|
||||
assert self.vocab_size > 0
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
|
@ -1,4 +1,4 @@
|
||||
mlx>=0.0.6
|
||||
mlx>=0.8.0
|
||||
sentencepiece
|
||||
torch
|
||||
numpy
|
||||
|
@ -26,20 +26,6 @@ class ModelArgs:
|
||||
moe: dict = None
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
@ -73,9 +59,6 @@ class Attention(nn.Module):
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
||||
values = mx.repeat(values, self.repeats, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
@ -86,11 +69,10 @@ class Attention(nn.Module):
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.wo(output), (keys, values)
|
||||
|
||||
|
||||
@ -144,8 +126,8 @@ class MOETransformerBlock(nn.Module):
|
||||
self.dim = args.dim
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = MOEFeedForward(args=args)
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
@ -170,7 +152,7 @@ class Mixtral(nn.Module):
|
||||
assert self.vocab_size > 0
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
self.layers = [MOETransformerBlock(args=args) for _ in range(args.n_layers)]
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
|
@ -1,4 +1,4 @@
|
||||
mlx
|
||||
mlx>=0.8.0
|
||||
sentencepiece
|
||||
torch
|
||||
numpy
|
||||
|
@ -5,7 +5,6 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .layers import LayerNorm
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -97,7 +96,7 @@ class TransformerBlock(nn.Module):
|
||||
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = LayerNorm(
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||
)
|
||||
self.args = args
|
||||
@ -125,7 +124,7 @@ class CohereModel(nn.Module):
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = LayerNorm(
|
||||
self.norm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||
)
|
||||
|
||||
|
@ -23,13 +23,6 @@ class ModelArgs(BaseModelArgs):
|
||||
rope_traditional: bool = False
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def rms_norm(x, weight, eps):
|
||||
x = x.astype(mx.float32)
|
||||
x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
|
||||
return (1.0 + weight) * x.astype(weight.dtype)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
@ -37,7 +30,7 @@ class RMSNorm(nn.Module):
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return rms_norm(x, self.weight, self.eps)
|
||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
@ -1,80 +0,0 @@
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def rms_norm(x, weight, eps):
|
||||
x = x.astype(mx.float32)
|
||||
x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
|
||||
return weight * x.astype(weight.dtype)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def ln_norm(x, eps, weight=None, bias=None):
|
||||
"""
|
||||
Layer normalization for input tensor x.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Input tensor.
|
||||
eps (float, optional): Small value to avoid division by zero.
|
||||
weight (np.ndarray, optional): Weight tensor for normalization.
|
||||
bias (np.ndarray, optional): Bias tensor for normalization.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Normalized tensor.
|
||||
"""
|
||||
t = x.dtype
|
||||
x = x.astype(mx.float32)
|
||||
|
||||
# Compute mean and variance along the last dimension
|
||||
means = mx.mean(x, axis=-1, keepdims=True)
|
||||
var = mx.var(x, axis=-1, keepdims=True)
|
||||
|
||||
# Normalize the input tensor
|
||||
x = (x - means) * mx.rsqrt(var + eps)
|
||||
x = x.astype(t)
|
||||
|
||||
# Apply weight and bias if provided
|
||||
if weight is not None:
|
||||
x = x * weight
|
||||
if bias is not None:
|
||||
x = x + bias
|
||||
return x
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(
|
||||
self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True
|
||||
):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.dims = dims
|
||||
self.affine = affine
|
||||
|
||||
if affine:
|
||||
self.weight = mx.ones((dims,))
|
||||
self.bias = mx.zeros((dims,)) if bias else None
|
||||
|
||||
def _extra_repr(self):
|
||||
return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.affine:
|
||||
if self.bias is not None:
|
||||
return ln_norm(x, self.eps, self.weight, self.bias)
|
||||
else:
|
||||
return ln_norm(x, self.eps, self.weight)
|
||||
else:
|
||||
return ln_norm(x, self.eps)
|
@ -5,7 +5,6 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .layers import RMSNorm
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -113,8 +112,10 @@ class TransformerBlock(nn.Module):
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
@ -141,7 +142,7 @@ class LlamaModel(nn.Module):
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -6,7 +6,6 @@ import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .layers import RMSNorm
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -146,7 +145,7 @@ class MixtralSparseMoeBlock(nn.Module):
|
||||
if self.training:
|
||||
mx.eval(inds)
|
||||
inds = np.array(inds)
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
|
||||
for e, expert in enumerate(self.experts):
|
||||
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||
if idx1.size == 0:
|
||||
@ -173,8 +172,10 @@ class MixtralDecoderLayer(nn.Module):
|
||||
self.self_attn = MixtralAttention(args)
|
||||
|
||||
self.block_sparse_moe = MixtralSparseMoeBlock(args)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -199,7 +200,7 @@ class MixtralModel(nn.Module):
|
||||
self.layers = [
|
||||
MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -6,7 +6,6 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .layers import LayerNorm
|
||||
|
||||
try:
|
||||
import hf_olmo
|
||||
@ -46,8 +45,8 @@ class TransformerBlock(nn.Module):
|
||||
self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False)
|
||||
self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False)
|
||||
|
||||
self.att_norm = LayerNorm(dim, affine=False)
|
||||
self.ff_norm = LayerNorm(dim, affine=False)
|
||||
self.att_norm = nn.LayerNorm(dim, affine=False)
|
||||
self.ff_norm = nn.LayerNorm(dim, affine=False)
|
||||
|
||||
head_dim = dim // self.n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
@ -120,7 +119,7 @@ class Transformer(nn.Module):
|
||||
self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
|
||||
if not self.weight_tying:
|
||||
self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
|
||||
self.norm = LayerNorm(args.d_model, affine=False)
|
||||
self.norm = nn.LayerNorm(args.d_model, affine=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -6,7 +6,6 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .layers import LayerNorm
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -122,7 +121,9 @@ class PhiDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.self_attn = PhiAttention(config=config)
|
||||
self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
self.mlp = PhiMLP(config)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
@ -137,7 +138,9 @@ class PhiModel(nn.Module):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)]
|
||||
self.final_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.final_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
def __call__(self, x, cache):
|
||||
x = self.embed_tokens(x)
|
||||
|
@ -7,8 +7,6 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .layers import LayerNorm
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
@ -116,7 +114,7 @@ class MOE(nn.Module):
|
||||
|
||||
if self.training:
|
||||
ys = []
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
|
||||
for e, expert in enumerate(self.mlp):
|
||||
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||
if idx1.size == 0:
|
||||
@ -141,7 +139,7 @@ class ParallelBlock(nn.Module):
|
||||
dims = config.model_dim
|
||||
mlp_dims = dims * 4
|
||||
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
|
||||
self.ln = LayerNorm(dims)
|
||||
self.ln = nn.LayerNorm(dims)
|
||||
self.moe = MOE(config, dims, mlp_dims)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
@ -179,7 +177,7 @@ class Embd(nn.Module):
|
||||
class OutputHead(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.ln = LayerNorm(config.model_dim)
|
||||
self.ln = nn.LayerNorm(config.model_dim)
|
||||
self.linear = nn.Linear(config.model_dim, config.num_vocab)
|
||||
|
||||
def __call__(self, inputs):
|
||||
|
@ -6,7 +6,6 @@ import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .layers import RMSNorm
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -82,9 +81,6 @@ class Attention(nn.Module):
|
||||
|
||||
# expand shared kv
|
||||
assert self.k_num_heads == self.v_num_heads
|
||||
repeats = self.config.n_shared_head
|
||||
key_states = mx.repeat(key_states, repeats, axis=1)
|
||||
value_states = mx.repeat(value_states, repeats, axis=1)
|
||||
|
||||
kv_seq_len = 0
|
||||
if cache is not None:
|
||||
@ -97,12 +93,14 @@ class Attention(nn.Module):
|
||||
key_states = mx.concatenate([cache[0], key_states], axis=2)
|
||||
value_states = mx.concatenate([cache[1], value_states], axis=2)
|
||||
|
||||
scores = (query_states * self.scale) @ key_states.transpose(0, 1, 3, 2)
|
||||
if attention_mask is not None:
|
||||
scores += attention_mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ value_states).transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
scale=self.scale,
|
||||
mask=attention_mask,
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
|
||||
return self.o_proj(output), (key_states, value_states)
|
||||
|
||||
|
||||
@ -127,7 +125,7 @@ class PlamoDecoderLayer(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = Attention(config)
|
||||
self.mlp = MLP(config)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -170,7 +168,7 @@ class PlamoModel(nn.Module):
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = PlamoDecoder(config) # type: ignore
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -5,7 +5,6 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .layers import RMSNorm
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -102,9 +101,9 @@ class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.attn = Attention(args)
|
||||
self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.mlp = MLP(args)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
@ -124,7 +123,7 @@ class QwenModel(nn.Module):
|
||||
super().__init__()
|
||||
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
|
||||
self.ln_f = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
|
||||
def __call__(self, inputs, mask=None, cache=None):
|
||||
x = self.wte(inputs)
|
||||
|
@ -5,7 +5,6 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .layers import RMSNorm
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -114,8 +113,10 @@ class TransformerBlock(nn.Module):
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
@ -142,7 +143,7 @@ class Qwen2Model(nn.Module):
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -6,7 +6,6 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .layers import LayerNorm
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -120,8 +119,10 @@ class DecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.self_attn = Attention(config=config)
|
||||
self.mlp = MLP(config.hidden_size, config.intermediate_size)
|
||||
self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.post_attention_layernorm = LayerNorm(
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
@ -138,7 +139,7 @@ class StableLM(nn.Module):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)]
|
||||
self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
x = self.embed_tokens(x)
|
||||
|
@ -5,7 +5,6 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .layers import LayerNorm
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -91,8 +90,8 @@ class TransformerBlock(nn.Module):
|
||||
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
|
||||
self.post_attention_layernorm = LayerNorm(
|
||||
self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.norm_epsilon
|
||||
)
|
||||
self.args = args
|
||||
@ -121,7 +120,7 @@ class Starcoder2Model(nn.Module):
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
|
||||
self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -1,4 +1,4 @@
|
||||
mlx>=0.6
|
||||
mlx>=0.8
|
||||
numpy
|
||||
transformers>=4.38.0
|
||||
protobuf
|
||||
|
@ -132,21 +132,6 @@ class MultiHeadAttention(nn.Module):
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
t = x.dtype
|
||||
output = self._norm(x).astype(t)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class DenseActivation(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
@ -182,8 +167,8 @@ class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.attention = MultiHeadAttention(config)
|
||||
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dense = DenseActivation(config)
|
||||
|
||||
def __call__(self, x, mask):
|
||||
@ -202,7 +187,7 @@ class TransformerEncoder(nn.Module):
|
||||
self.layers = [
|
||||
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
||||
]
|
||||
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
@ -217,9 +202,9 @@ class TransformerDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.self_attention = MultiHeadAttention(config)
|
||||
self.cross_attention = MultiHeadAttention(config)
|
||||
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln3 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dense = DenseActivation(config)
|
||||
|
||||
def __call__(
|
||||
@ -257,7 +242,7 @@ class TransformerDecoder(nn.Module):
|
||||
super().__init__()
|
||||
n_layers = getattr(config, "num_decoder_layers", config.num_layers)
|
||||
self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
|
||||
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
|
||||
|
||||
def __call__(self, x, memory, cache=None):
|
||||
|
@ -1,3 +1,3 @@
|
||||
mlx>=0.0.6
|
||||
mlx>=0.8.0
|
||||
transformers
|
||||
numpy
|
||||
|
@ -21,7 +21,9 @@ class TestModels(unittest.TestCase):
|
||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
|
||||
outputs, cache = model(mx.argmax(outputs[1, :], keepdims=True), cache=cache)
|
||||
outputs, cache = model(
|
||||
mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache
|
||||
)
|
||||
self.assertEqual(outputs.shape, (1, 1, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
|
||||
|
@ -2,8 +2,12 @@
|
||||
|
||||
This is an example of using MLX to fine-tune an LLM with low rank adaptation
|
||||
(LoRA) for a target task.[^lora] The example also supports quantized LoRA
|
||||
(QLoRA).[^qlora] The example works with Llama, Mistral, and Phi-2 style
|
||||
models available on Hugging Face.
|
||||
(QLoRA).[^qlora] The example works with Llama and Mistral style models
|
||||
available on Hugging Face.
|
||||
|
||||
> [!TIP]
|
||||
> For a more fully featured LLM package, checkout [MLX
|
||||
> LM](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm).
|
||||
|
||||
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
|
||||
generate SQL queries from natural language. However, the example is intended to
|
||||
|
@ -1,10 +1,11 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import models
|
||||
import utils
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
@ -12,11 +13,8 @@ from mlx.utils import tree_flatten
|
||||
def quantize(weights, config, args):
|
||||
quantized_config = copy.deepcopy(config)
|
||||
|
||||
# Get model classes
|
||||
model_class, model_args_class = utils._get_classes(config=config)
|
||||
|
||||
# Load the model:
|
||||
model = model_class(model_args_class.from_dict(config))
|
||||
model = models.Model(models.ModelArgs.from_dict(config))
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
# Quantize the model:
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
@ -7,7 +7,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import utils
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
from models.lora import LoRALinear
|
||||
from models import LoRALinear
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
@ -12,7 +12,7 @@ import mlx.optimizers as optim
|
||||
import numpy as np
|
||||
import utils as lora_utils
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
from models.lora import LoRALinear
|
||||
from models import LoRALinear
|
||||
|
||||
|
||||
def build_parser():
|
||||
|
101
lora/models.py
101
lora/models.py
@ -2,17 +2,13 @@
|
||||
|
||||
import glob
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -134,20 +130,6 @@ class LoRALinear(nn.Module):
|
||||
return y + self.scale * z
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
@ -192,13 +174,6 @@ class Attention(nn.Module):
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
if self.repeats > 1:
|
||||
keys, values = map(repeat, (keys, values))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
@ -209,11 +184,10 @@ class Attention(nn.Module):
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output), (keys, values)
|
||||
|
||||
|
||||
@ -235,8 +209,10 @@ class TransformerBlock(nn.Module):
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
@ -263,7 +239,7 @@ class LlamaModel(nn.Module):
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -299,60 +275,3 @@ class Model(nn.Module):
|
||||
):
|
||||
out, cache = self.model(inputs, cache)
|
||||
return self.lm_head(out), cache
|
||||
|
||||
|
||||
def load(path_or_hf_repo: str):
|
||||
# If the path exists, it will try to load model form it
|
||||
# otherwise download and cache from the hf_repo and cache
|
||||
model_path = Path(path_or_hf_repo)
|
||||
if not model_path.exists():
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo,
|
||||
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
|
||||
)
|
||||
)
|
||||
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
quantization = config.get("quantization", None)
|
||||
model_args = ModelArgs.from_dict(config)
|
||||
|
||||
weight_files = glob.glob(str(model_path / "*.safetensors"))
|
||||
if len(weight_files) == 0:
|
||||
raise FileNotFoundError("No safetensors found in {}".format(model_path))
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf).items())
|
||||
|
||||
model = Model(model_args)
|
||||
if quantization is not None:
|
||||
nn.QuantizedLinear.quantize_module(
|
||||
model,
|
||||
**quantization,
|
||||
linear_class_predicate=lambda m: isinstance(m, nn.Linear)
|
||||
and m.weight.shape[0] != 8,
|
||||
)
|
||||
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
mx.eval(model.parameters())
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
return model, tokenizer, config
|
||||
|
||||
|
||||
def generate(prompt: mx.array, model: Model, temp: float = 0.0):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
y = prompt
|
||||
cache = None
|
||||
while True:
|
||||
logits, cache = model(y[None], cache=cache)
|
||||
logits = logits[:, -1, :]
|
||||
y = sample(logits)
|
||||
yield y
|
||||
|
@ -1,15 +0,0 @@
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelArgs:
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
@ -1,202 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
model_type: str = None
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] != "linear":
|
||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
self.repeats = n_heads // n_kv_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||
else 1
|
||||
)
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
if self.repeats > 1:
|
||||
keys, values = map(repeat, (keys, values))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output), (keys, values)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out, cache
|
||||
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
h, cache[e] = layer(h, mask, cache[e])
|
||||
|
||||
return self.norm(h), cache
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model = LlamaModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out, cache = self.model(inputs, cache)
|
||||
return self.lm_head(out), cache
|
@ -1,86 +0,0 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
@staticmethod
|
||||
def from_linear(linear: nn.Linear, rank: int = 8):
|
||||
# TODO remove when input_dims and output_dims are attributes
|
||||
# on linear and quantized linear
|
||||
output_dims, input_dims = linear.weight.shape
|
||||
if isinstance(linear, nn.QuantizedLinear):
|
||||
input_dims *= 32 // linear.bits
|
||||
lora_lin = LoRALinear(input_dims, output_dims, rank)
|
||||
lora_lin.linear = linear
|
||||
return lora_lin
|
||||
|
||||
def to_linear(self, de_quantize: bool = False):
|
||||
linear = self.linear
|
||||
bias = "bias" in linear
|
||||
weight = linear.weight
|
||||
is_quantized = isinstance(linear, nn.QuantizedLinear)
|
||||
|
||||
# Use the same type as the linear weight if not quantized
|
||||
dtype = weight.dtype
|
||||
|
||||
if is_quantized:
|
||||
dtype = mx.float16
|
||||
weight = mx.dequantize(
|
||||
weight,
|
||||
linear.scales,
|
||||
linear.biases,
|
||||
linear.group_size,
|
||||
linear.bits,
|
||||
)
|
||||
output_dims, input_dims = weight.shape
|
||||
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
|
||||
lora_b = (self.scale * self.lora_b.T).astype(dtype)
|
||||
lora_a = self.lora_a.T.astype(dtype)
|
||||
fused_linear.weight = weight + lora_b @ lora_a
|
||||
if bias:
|
||||
fused_linear.bias = linear.bias
|
||||
|
||||
if is_quantized and not de_quantize:
|
||||
fused_linear = nn.QuantizedLinear.from_linear(
|
||||
fused_linear,
|
||||
linear.group_size,
|
||||
linear.bits,
|
||||
)
|
||||
|
||||
return fused_linear
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
lora_rank: int = 8,
|
||||
bias: bool = False,
|
||||
scale: float = 20.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Regular linear layer weights
|
||||
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
|
||||
# Scale for low-rank update
|
||||
self.scale = scale
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(input_dims)
|
||||
self.lora_a = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(input_dims, lora_rank),
|
||||
)
|
||||
self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
|
||||
|
||||
def __call__(self, x):
|
||||
dtype = self.linear.weight.dtype
|
||||
if isinstance(self.linear, nn.QuantizedLinear):
|
||||
dtype = self.linear.scales.dtype
|
||||
y = self.linear(x.astype(dtype))
|
||||
z = (x @ self.lora_a) @ self.lora_b
|
||||
return y + self.scale * z
|
@ -1,253 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
vocab_size: int = 32000
|
||||
max_position_embeddings: int = 4096 * 32
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 14336
|
||||
num_hidden_layers: int = 32
|
||||
num_attention_heads: int = 32
|
||||
num_experts_per_tok: int = 2
|
||||
num_key_value_heads: int = 8
|
||||
num_local_experts: int = 8
|
||||
rms_norm_eps: float = 1e-5
|
||||
vocab_size: int
|
||||
rope_theta: float = 1e6
|
||||
rope_traditional: bool = False
|
||||
model_type: str = None
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class MixtralAttention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.num_heads = args.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = args.num_key_value_heads
|
||||
self.max_position_embeddings = args.max_position_embeddings
|
||||
self.rope_theta = args.rope_theta
|
||||
|
||||
self.repeats = self.num_heads // self.num_key_value_heads
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.head_dim, self.hidden_size, bias=False
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
self.head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.num_heads, L, -1])
|
||||
|
||||
if self.repeats > 1:
|
||||
keys, values = map(repeat, (keys, values))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output), (keys, values)
|
||||
|
||||
|
||||
class MixtralBLockSparseTop2MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.ffn_dim = args.intermediate_size
|
||||
self.hidden_dim = args.hidden_size
|
||||
|
||||
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
||||
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
|
||||
|
||||
self.act_fn = nn.silu
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
|
||||
current_hidden_states = self.w2(current_hidden_states)
|
||||
return current_hidden_states
|
||||
|
||||
|
||||
class MixtralSparseMoeBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_dim = args.hidden_size
|
||||
self.ffn_dim = args.intermediate_size
|
||||
self.num_experts = args.num_local_experts
|
||||
self.num_experts_per_tok = args.num_experts_per_tok
|
||||
|
||||
# gating
|
||||
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
||||
|
||||
self.experts = [
|
||||
MixtralBLockSparseTop2MLP(args=args) for _ in range(self.num_experts)
|
||||
]
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
ne = self.num_experts_per_tok
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
gates = self.gate(x)
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne])
|
||||
|
||||
scores = mx.softmax(
|
||||
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
|
||||
axis=-1,
|
||||
).astype(gates.dtype)
|
||||
|
||||
mx.eval(inds)
|
||||
inds = np.array(inds)
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
|
||||
for e, expert in enumerate(self.experts):
|
||||
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||
if idx1.size == 0:
|
||||
continue
|
||||
y[idx1, idx2] = expert(x[idx1])
|
||||
|
||||
y = (y * scores[:, :, None]).sum(axis=1)
|
||||
|
||||
return y.reshape(orig_shape)
|
||||
|
||||
|
||||
class MixtralDecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
|
||||
self.self_attn = MixtralAttention(args)
|
||||
|
||||
self.block_sparse_moe = MixtralSparseMoeBlock(args)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.block_sparse_moe(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out, cache
|
||||
|
||||
|
||||
class MixtralModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = None
|
||||
T = h.shape[1]
|
||||
if T > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
h, cache[e] = layer(h, mask, cache[e])
|
||||
|
||||
return self.norm(h), cache
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model = MixtralModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out, cache = self.model(inputs, cache)
|
||||
return self.lm_head(out), cache
|
@ -1,138 +0,0 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
n_positions: int = 2048
|
||||
vocab_size: int = 51200
|
||||
n_embd: int = 2560
|
||||
n_head: int = 32
|
||||
n_layer: int = 32
|
||||
rotary_dim: int = 32
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
|
||||
|
||||
|
||||
class RoPEAttention(nn.Module):
|
||||
def __init__(self, dims: int, n_head: int, rotary_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.n_head = n_head
|
||||
|
||||
self.q_proj = nn.Linear(dims, dims)
|
||||
self.k_proj = nn.Linear(dims, dims)
|
||||
self.v_proj = nn.Linear(dims, dims)
|
||||
self.dense = nn.Linear(dims, dims)
|
||||
|
||||
self.rope = nn.RoPE(rotary_dim, traditional=False)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Extract some shapes
|
||||
n_head = self.n_head
|
||||
B, L, D = queries.shape
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
queries = queries.astype(mx.float32)
|
||||
keys = keys.astype(mx.float32)
|
||||
|
||||
# Finally perform the attention computation
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
|
||||
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.dense(values_hat), (keys, values)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(dim, hidden_dim)
|
||||
self.fc2 = nn.Linear(hidden_dim, dim)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
|
||||
|
||||
class ParallelBlock(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
dims = config.n_embd
|
||||
mlp_dims = dims * 4
|
||||
self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim)
|
||||
self.input_layernorm = LayerNorm(dims)
|
||||
self.mlp = MLP(dims, mlp_dims)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
h = self.input_layernorm(x)
|
||||
attn_h, cache = self.self_attn(h, mask, cache)
|
||||
ff_h = self.mlp(h)
|
||||
return attn_h + ff_h + x, cache
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
|
||||
self.layers = [ParallelBlock(config) for i in range(config.n_layer)]
|
||||
self.final_layernorm = LayerNorm(config.n_embd)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
x = self.embed_tokens(x)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
x, cache[e] = layer(x, mask, cache[e])
|
||||
return self.final_layernorm(x), cache
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model = Transformer(config)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
mask = None
|
||||
if x.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
y, cache = self.model(x, mask, cache)
|
||||
return self.lm_head(y), cache
|
@ -1,3 +1,3 @@
|
||||
mlx>=0.0.7
|
||||
mlx>=0.8.0
|
||||
transformers
|
||||
numpy
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import glob
|
||||
import json
|
||||
@ -8,40 +8,10 @@ from typing import Generator
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import models.llama as llama
|
||||
import models.mixtral as mixtral
|
||||
import models.phi2 as phi2
|
||||
import models
|
||||
import transformers
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Constants
|
||||
MODEL_MAPPING = {
|
||||
"llama": llama,
|
||||
"mistral": llama, # mistral is compatible with llama
|
||||
"phi": phi2,
|
||||
"mixtral": mixtral,
|
||||
}
|
||||
|
||||
|
||||
def _get_classes(config: dict):
|
||||
"""
|
||||
Retrieve the model and model args classes based on the configuration.
|
||||
|
||||
Args:
|
||||
config (dict): The model configuration.
|
||||
|
||||
Returns:
|
||||
A tuple containing the Model class and the ModelArgs class.
|
||||
"""
|
||||
model_type = config["model_type"]
|
||||
if model_type not in MODEL_MAPPING:
|
||||
msg = f"Model type {model_type} not supported."
|
||||
logging.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
arch = MODEL_MAPPING[model_type]
|
||||
return arch.Model, arch.ModelArgs
|
||||
|
||||
|
||||
def fetch_from_hub(hf_path: str):
|
||||
model_path = snapshot_download(
|
||||
@ -157,9 +127,8 @@ def load(path_or_hf_repo: str):
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf).items())
|
||||
|
||||
model_class, model_args_class = _get_classes(config=config)
|
||||
model_args = model_args_class.from_dict(config)
|
||||
model = model_class(model_args)
|
||||
model_args = models.ModelArgs.from_dict(config)
|
||||
model = models.Model(model_args)
|
||||
if quantization is not None:
|
||||
nn.QuantizedLinear.quantize_module(
|
||||
model,
|
||||
|
@ -1,3 +1,3 @@
|
||||
mlx
|
||||
mlx>=0.8.0
|
||||
numpy
|
||||
transformers
|
||||
|
29
t5/t5.py
29
t5/t5.py
@ -134,21 +134,6 @@ class MultiHeadAttention(nn.Module):
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
t = x.dtype
|
||||
output = self._norm(x).astype(t)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class DenseActivation(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
@ -184,8 +169,8 @@ class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.attention = MultiHeadAttention(config)
|
||||
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dense = DenseActivation(config)
|
||||
|
||||
def __call__(self, x, mask):
|
||||
@ -204,7 +189,7 @@ class TransformerEncoder(nn.Module):
|
||||
self.layers = [
|
||||
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
||||
]
|
||||
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
@ -219,9 +204,9 @@ class TransformerDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.self_attention = MultiHeadAttention(config)
|
||||
self.cross_attention = MultiHeadAttention(config)
|
||||
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln3 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dense = DenseActivation(config)
|
||||
|
||||
def __call__(
|
||||
@ -252,7 +237,7 @@ class TransformerDecoder(nn.Module):
|
||||
super().__init__()
|
||||
n_layers = getattr(config, "num_decoder_layers", config.num_layers)
|
||||
self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
|
||||
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
|
||||
|
||||
def __call__(self, x, memory, mask, memory_mask, cache=None):
|
||||
|
@ -1,4 +1,4 @@
|
||||
mlx>=0.1
|
||||
mlx>=0.8
|
||||
numba
|
||||
numpy
|
||||
torch
|
||||
|
@ -37,11 +37,6 @@ def sinusoids(length, channels, max_timescale=10000):
|
||||
return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int):
|
||||
super().__init__()
|
||||
@ -98,17 +93,17 @@ class ResidualAttentionBlock(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.attn = MultiHeadAttention(n_state, n_head)
|
||||
self.attn_ln = LayerNorm(n_state)
|
||||
self.attn_ln = nn.LayerNorm(n_state)
|
||||
|
||||
self.cross_attn = (
|
||||
MultiHeadAttention(n_state, n_head) if cross_attention else None
|
||||
)
|
||||
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
|
||||
|
||||
n_mlp = n_state * 4
|
||||
self.mlp1 = nn.Linear(n_state, n_mlp)
|
||||
self.mlp2 = nn.Linear(n_mlp, n_state)
|
||||
self.mlp_ln = LayerNorm(n_state)
|
||||
self.mlp_ln = nn.LayerNorm(n_state)
|
||||
|
||||
def __call__(self, x, xa=None, mask=None, kv_cache=None):
|
||||
kv, cross_kv = kv_cache if kv_cache else (None, None)
|
||||
@ -140,7 +135,7 @@ class AudioEncoder(nn.Module):
|
||||
self._positional_embedding = sinusoids(n_ctx, n_state).astype(dtype)
|
||||
|
||||
self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
|
||||
self.ln_post = LayerNorm(n_state)
|
||||
self.ln_post = nn.LayerNorm(n_state)
|
||||
|
||||
def __call__(self, x):
|
||||
x = nn.gelu(self.conv1(x)).astype(x.dtype)
|
||||
@ -174,7 +169,7 @@ class TextDecoder(nn.Module):
|
||||
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
|
||||
for _ in range(n_layer)
|
||||
]
|
||||
self.ln = LayerNorm(n_state)
|
||||
self.ln = nn.LayerNorm(n_state)
|
||||
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(
|
||||
dtype
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user