Switch to fast RMS/LN Norm (#603)

* use nn.RMSNorm, use sdpa, cleanup

* bump mlx versions

* minor update

* use fast layer norm

* version bump

* update requirement for whisper

* update requirement for gguf
Awni Hannun 2024-03-23 07:13:51 -07:00 committed by GitHub
parent fbed720d6f
commit b8a348c1b8
44 changed files with 144 additions and 1155 deletions
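The diffs below repeat one pattern across the models: the hand-written RMSNorm/LayerNorm modules and the explicit softmax attention are replaced by `nn.RMSNorm`/`nn.LayerNorm` and `mx.fast.scaled_dot_product_attention`, which also handles grouped KV heads so the `mx.repeat` key/value expansion can be dropped. A minimal sketch of the before/after (shapes and tensor names here are illustrative, not taken from any one file):

```python
import mlx.core as mx
import mlx.nn as nn

B, n_heads, n_kv_heads, L, head_dim = 1, 8, 2, 16, 64
scale = head_dim**-0.5
queries = mx.random.normal((B, n_heads, L, head_dim))
keys = mx.random.normal((B, n_kv_heads, L, head_dim))
values = mx.random.normal((B, n_kv_heads, L, head_dim))
mask = nn.MultiHeadAttention.create_additive_causal_mask(L).astype(queries.dtype)

# Before: expand shared KV heads and compute softmax attention by hand.
k = mx.repeat(keys, n_heads // n_kv_heads, axis=1)
v = mx.repeat(values, n_heads // n_kv_heads, axis=1)
scores = (queries * scale) @ k.transpose(0, 1, 3, 2) + mask
ref = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) @ v

# After: one fused call; grouped KV heads are handled without mx.repeat.
out = mx.fast.scaled_dot_product_attention(
    queries, keys, values, scale=scale, mask=mask
)
print(mx.allclose(ref, out, atol=1e-4))

# The norms follow the same pattern: the custom RMSNorm modules become
# nn.RMSNorm, which uses the fused mx.fast.rms_norm kernel under the hood.
norm = nn.RMSNorm(head_dim, eps=1e-5)
print(norm(values).shape)
```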

View File

@@ -45,20 +45,6 @@ class TextConfig:
                 raise ValueError("rope_scaling 'type' currently only supports 'linear'")
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
 class Attention(nn.Module):
     def __init__(self, config: TextConfig):
         super().__init__()
@@ -105,10 +91,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -119,11 +101,10 @@
             queries = self.rope(queries)
             keys = self.rope(keys)
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)
@@ -145,8 +126,8 @@ class TransformerBlock(nn.Module):
         self.hidden_size = config.hidden_size
         self.self_attn = Attention(config)
         self.mlp = MLP(config.hidden_size, config.intermediate_size)
-        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(
+        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )
         self.config = config
@@ -175,7 +156,7 @@ class Llama(nn.Module):
         self.layers = [
             TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
         ]
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
     def __call__(
         self,

View File

@@ -1,4 +1,4 @@
-mlx>=0.5.0
+mlx>=0.8.0
 numpy
 transformers
 torch

View File

@@ -49,20 +49,6 @@ class ModelArgs:
             )
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -107,10 +93,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -121,11 +103,10 @@
             queries = self.rope(queries)
             keys = self.rope(keys)
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)
@@ -147,8 +128,10 @@ class TransformerBlock(nn.Module):
         self.hidden_size = args.hidden_size
         self.self_attn = Attention(args)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
         self.args = args
     def __call__(
@@ -175,7 +158,7 @@ class LlamaModel(nn.Module):
         self.layers = [
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
     def __call__(
         self,

View File

@@ -1,4 +1,4 @@
-mlx>=0.0.11
+mlx>=0.8
 numpy
 protobuf==3.20.2
 sentencepiece

View File

@@ -28,20 +28,6 @@ class ModelArgs:
     rope_traditional: bool = True
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -120,8 +106,8 @@ class TransformerBlock(nn.Module):
         self.dim = args.dim
         self.attention = Attention(args)
         self.feed_forward = FeedForward(args=args)
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
         self.args = args
     def __call__(
@@ -144,7 +130,7 @@ class Llama(nn.Module):
         self.vocab_size = args.vocab_size
         self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
         self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
-        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
         self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
     def __call__(self, x):

View File

@ -1,4 +1,4 @@
@@ -1,4 +1,4 @@
-mlx>=0.0.6
+mlx>=0.8.0
 sentencepiece
 torch
 numpy

View File

@@ -26,20 +26,6 @@ class ModelArgs:
     rope_theta: float = 10000
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -73,9 +59,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        keys = mx.repeat(keys, self.repeats, axis=1)
-        values = mx.repeat(values, self.repeats, axis=1)
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -86,11 +69,10 @@ class Attention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.wo(output), (keys, values)
@@ -113,8 +95,8 @@ class TransformerBlock(nn.Module):
         self.dim = args.dim
         self.attention = Attention(args)
         self.feed_forward = FeedForward(args=args)
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
         self.args = args
     def __call__(
@@ -139,7 +121,7 @@ class Mistral(nn.Module):
         assert self.vocab_size > 0
         self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
         self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
-        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
         self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
     def __call__(
def __call__( def __call__(

View File

@ -1,4 +1,4 @@
mlx>=0.0.6 mlx>=0.8.0
sentencepiece sentencepiece
torch torch
numpy numpy

View File

@@ -26,20 +26,6 @@ class ModelArgs:
     moe: dict = None
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -73,9 +59,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        keys = mx.repeat(keys, self.repeats, axis=1)
-        values = mx.repeat(values, self.repeats, axis=1)
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -86,11 +69,10 @@ class Attention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.wo(output), (keys, values)
@@ -144,8 +126,8 @@ class MOETransformerBlock(nn.Module):
         self.dim = args.dim
         self.attention = Attention(args)
         self.feed_forward = MOEFeedForward(args=args)
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
         self.args = args
     def __call__(
@@ -170,7 +152,7 @@ class Mixtral(nn.Module):
         assert self.vocab_size > 0
         self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
         self.layers = [MOETransformerBlock(args=args) for _ in range(args.n_layers)]
-        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
         self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
     def __call__(

View File

@ -1,4 +1,4 @@
mlx mlx>=0.8.0
sentencepiece sentencepiece
torch torch
numpy numpy

View File

@@ -5,7 +5,6 @@ import mlx.core as mx
 import mlx.nn as nn
 from .base import BaseModelArgs
-from .layers import LayerNorm
 @dataclass
@@ -97,7 +96,7 @@ class TransformerBlock(nn.Module):
         self.self_attn = Attention(args)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = LayerNorm(
+        self.input_layernorm = nn.LayerNorm(
             args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
         )
         self.args = args
@@ -125,7 +124,7 @@ class CohereModel(nn.Module):
         self.layers = [
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = LayerNorm(
+        self.norm = nn.LayerNorm(
             args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
         )

View File

@@ -23,13 +23,6 @@ class ModelArgs(BaseModelArgs):
     rope_traditional: bool = False
-@partial(mx.compile, shapeless=True)
-def rms_norm(x, weight, eps):
-    x = x.astype(mx.float32)
-    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
-    return (1.0 + weight) * x.astype(weight.dtype)
 class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
@@ -37,7 +30,7 @@ class RMSNorm(nn.Module):
         self.eps = eps
     def __call__(self, x):
-        return rms_norm(x, self.weight, self.eps)
+        return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
 class Attention(nn.Module):
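Gemma keeps a thin RMSNorm wrapper rather than switching to `nn.RMSNorm` because its checkpoints store the norm weight as an offset from one; the wrapper now forwards to the fused kernel with `1.0 + self.weight`. A small sketch (hypothetical shapes and values) checking that this matches the compiled helper removed above:

```python
import mlx.core as mx

x = mx.random.normal((2, 8, 256))
weight = mx.random.normal((256,)) * 0.1  # Gemma-style: stored as an offset from 1.0
eps = 1e-6

def reference_rms_norm(x, weight, eps):
    # Mirrors the removed rms_norm helper: normalize in float32, then scale by (1 + weight).
    x32 = x.astype(mx.float32)
    x32 = x32 * mx.rsqrt(x32.square().mean(-1, keepdims=True) + eps)
    return (1.0 + weight) * x32.astype(weight.dtype)

fast = mx.fast.rms_norm(x, 1.0 + weight, eps)
print(mx.allclose(reference_rms_norm(x, weight, eps), fast, atol=1e-5))
```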

View File

@@ -1,80 +0,0 @@
-from functools import partial
-import mlx.core as mx
-import mlx.nn as nn
-@partial(mx.compile, shapeless=True)
-def rms_norm(x, weight, eps):
-    x = x.astype(mx.float32)
-    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
-    return weight * x.astype(weight.dtype)
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-    def __call__(self, x):
-        return rms_norm(x, self.weight, self.eps)
-@partial(mx.compile, shapeless=True)
-def ln_norm(x, eps, weight=None, bias=None):
-    """
-    Layer normalization for input tensor x.
-    Args:
-        x (np.ndarray): Input tensor.
-        eps (float, optional): Small value to avoid division by zero.
-        weight (np.ndarray, optional): Weight tensor for normalization.
-        bias (np.ndarray, optional): Bias tensor for normalization.
-    Returns:
-        np.ndarray: Normalized tensor.
-    """
-    t = x.dtype
-    x = x.astype(mx.float32)
-    # Compute mean and variance along the last dimension
-    means = mx.mean(x, axis=-1, keepdims=True)
-    var = mx.var(x, axis=-1, keepdims=True)
-    # Normalize the input tensor
-    x = (x - means) * mx.rsqrt(var + eps)
-    x = x.astype(t)
-    # Apply weight and bias if provided
-    if weight is not None:
-        x = x * weight
-    if bias is not None:
-        x = x + bias
-    return x
-class LayerNorm(nn.Module):
-    def __init__(
-        self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True
-    ):
-        super().__init__()
-        self.eps = eps
-        self.dims = dims
-        self.affine = affine
-        if affine:
-            self.weight = mx.ones((dims,))
-            self.bias = mx.zeros((dims,)) if bias else None
-    def _extra_repr(self):
-        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
-    def __call__(self, x: mx.array) -> mx.array:
-        if self.affine:
-            if self.bias is not None:
-                return ln_norm(x, self.eps, self.weight, self.bias)
-            else:
-                return ln_norm(x, self.eps, self.weight)
-        else:
-            return ln_norm(x, self.eps)
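With the custom helpers above deleted, the call sites rely on `nn.LayerNorm` and `nn.RMSNorm` directly; the built-in layer exposes the same knobs the removed `LayerNorm` offered. A short sketch of the variants this commit uses (the dimension here is made up for illustration):

```python
import mlx.core as mx
import mlx.nn as nn

dims = 512
x = mx.random.normal((4, dims))

full = nn.LayerNorm(dims, eps=1e-5)                     # weight and bias (the default)
no_bias = nn.LayerNorm(dims, eps=1e-5, bias=False)      # weight only, the flag Cohere's config toggles
no_affine = nn.LayerNorm(dims, eps=1e-5, affine=False)  # plain normalization, as in the OLMo blocks

for layer in (full, no_bias, no_affine):
    print(layer(x).shape)
```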

View File

@@ -5,7 +5,6 @@ import mlx.core as mx
 import mlx.nn as nn
 from .base import BaseModelArgs
-from .layers import RMSNorm
 @dataclass
@@ -113,8 +112,10 @@ class TransformerBlock(nn.Module):
         self.hidden_size = args.hidden_size
         self.self_attn = Attention(args)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
         self.args = args
     def __call__(
@@ -141,7 +142,7 @@ class LlamaModel(nn.Module):
         self.layers = [
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
     def __call__(
         self,

View File

@@ -6,7 +6,6 @@ import mlx.nn as nn
 import numpy as np
 from .base import BaseModelArgs
-from .layers import RMSNorm
 @dataclass
@@ -146,7 +145,7 @@ class MixtralSparseMoeBlock(nn.Module):
         if self.training:
             mx.eval(inds)
             inds = np.array(inds)
-            y = mx.zeros((x.shape[0], ne, x.shape[-1]))
+            y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
             for e, expert in enumerate(self.experts):
                 idx1, idx2 = map(mx.array, np.where(inds == e))
                 if idx1.size == 0:
@@ -173,8 +172,10 @@ class MixtralDecoderLayer(nn.Module):
         self.self_attn = MixtralAttention(args)
         self.block_sparse_moe = MixtralSparseMoeBlock(args)
-        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
     def __call__(
         self,
@@ -199,7 +200,7 @@ class MixtralModel(nn.Module):
         self.layers = [
             MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
     def __call__(
         self,

View File

@@ -6,7 +6,6 @@ import mlx.core as mx
 import mlx.nn as nn
 from .base import BaseModelArgs
-from .layers import LayerNorm
 try:
     import hf_olmo
@@ -46,8 +45,8 @@ class TransformerBlock(nn.Module):
         self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False)
         self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False)
-        self.att_norm = LayerNorm(dim, affine=False)
-        self.ff_norm = LayerNorm(dim, affine=False)
+        self.att_norm = nn.LayerNorm(dim, affine=False)
+        self.ff_norm = nn.LayerNorm(dim, affine=False)
         head_dim = dim // self.n_heads
         self.scale = head_dim**-0.5
@@ -120,7 +119,7 @@ class Transformer(nn.Module):
         self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
         if not self.weight_tying:
             self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
-        self.norm = LayerNorm(args.d_model, affine=False)
+        self.norm = nn.LayerNorm(args.d_model, affine=False)
     def __call__(
         self,

View File

@@ -6,7 +6,6 @@ import mlx.core as mx
 import mlx.nn as nn
 from .base import BaseModelArgs
-from .layers import LayerNorm
 @dataclass
@@ -122,7 +121,9 @@ class PhiDecoderLayer(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.self_attn = PhiAttention(config=config)
-        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.input_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps
+        )
         self.mlp = PhiMLP(config)
     def __call__(self, x, mask, cache):
@@ -137,7 +138,9 @@ class PhiModel(nn.Module):
         super().__init__()
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)]
-        self.final_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.final_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps
+        )
     def __call__(self, x, cache):
         x = self.embed_tokens(x)

View File

@@ -7,8 +7,6 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from .layers import LayerNorm
 @dataclass
 class ModelArgs:
@@ -116,7 +114,7 @@ class MOE(nn.Module):
         if self.training:
             ys = []
-            y = mx.zeros((x.shape[0], ne, x.shape[-1]))
+            y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
             for e, expert in enumerate(self.mlp):
                 idx1, idx2 = map(mx.array, np.where(inds == e))
                 if idx1.size == 0:
@@ -141,7 +139,7 @@ class ParallelBlock(nn.Module):
         dims = config.model_dim
         mlp_dims = dims * 4
         self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
-        self.ln = LayerNorm(dims)
+        self.ln = nn.LayerNorm(dims)
         self.moe = MOE(config, dims, mlp_dims)
     def __call__(self, x, mask, cache):
@@ -179,7 +177,7 @@ class Embd(nn.Module):
 class OutputHead(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
-        self.ln = LayerNorm(config.model_dim)
+        self.ln = nn.LayerNorm(config.model_dim)
         self.linear = nn.Linear(config.model_dim, config.num_vocab)
     def __call__(self, inputs):

View File

@@ -6,7 +6,6 @@ import mlx.nn as nn
 import numpy as np
 from .base import BaseModelArgs
-from .layers import RMSNorm
 @dataclass
@@ -82,9 +81,6 @@ class Attention(nn.Module):
         # expand shared kv
         assert self.k_num_heads == self.v_num_heads
-        repeats = self.config.n_shared_head
-        key_states = mx.repeat(key_states, repeats, axis=1)
-        value_states = mx.repeat(value_states, repeats, axis=1)
         kv_seq_len = 0
         if cache is not None:
@@ -97,12 +93,14 @@ class Attention(nn.Module):
             key_states = mx.concatenate([cache[0], key_states], axis=2)
             value_states = mx.concatenate([cache[1], value_states], axis=2)
-        scores = (query_states * self.scale) @ key_states.transpose(0, 1, 3, 2)
-        if attention_mask is not None:
-            scores += attention_mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ value_states).transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            scale=self.scale,
+            mask=attention_mask,
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
         return self.o_proj(output), (key_states, value_states)
@@ -127,7 +125,7 @@ class PlamoDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         self.self_attn = Attention(config)
         self.mlp = MLP(config)
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
     def __call__(
         self,
@@ -170,7 +168,7 @@ class PlamoModel(nn.Module):
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = PlamoDecoder(config)  # type: ignore
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
     def __call__(
         self,

View File

@@ -5,7 +5,6 @@ import mlx.core as mx
 import mlx.nn as nn
 from .base import BaseModelArgs
-from .layers import RMSNorm
 @dataclass
@@ -102,9 +101,9 @@ class TransformerBlock(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-        self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
         self.attn = Attention(args)
-        self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
         self.mlp = MLP(args)
     def __call__(self, x, mask=None, cache=None):
@@ -124,7 +123,7 @@ class QwenModel(nn.Module):
         super().__init__()
         self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
         self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
-        self.ln_f = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
     def __call__(self, inputs, mask=None, cache=None):
         x = self.wte(inputs)

View File

@@ -5,7 +5,6 @@ import mlx.core as mx
 import mlx.nn as nn
 from .base import BaseModelArgs
-from .layers import RMSNorm
 @dataclass
@@ -114,8 +113,10 @@ class TransformerBlock(nn.Module):
         self.hidden_size = args.hidden_size
         self.self_attn = Attention(args)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
         self.args = args
     def __call__(
@@ -142,7 +143,7 @@ class Qwen2Model(nn.Module):
         self.layers = [
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
     def __call__(
         self,

View File

@@ -6,7 +6,6 @@ import mlx.core as mx
 import mlx.nn as nn
 from .base import BaseModelArgs
-from .layers import LayerNorm
 @dataclass
@@ -120,8 +119,10 @@ class DecoderLayer(nn.Module):
         super().__init__()
         self.self_attn = Attention(config=config)
         self.mlp = MLP(config.hidden_size, config.intermediate_size)
-        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.post_attention_layernorm = LayerNorm(
+        self.input_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps
+        )
+        self.post_attention_layernorm = nn.LayerNorm(
             config.hidden_size, eps=config.layer_norm_eps
         )
@@ -138,7 +139,7 @@ class StableLM(nn.Module):
         super().__init__()
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)]
-        self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
     def __call__(self, x, mask, cache):
         x = self.embed_tokens(x)

View File

@@ -5,7 +5,6 @@ import mlx.core as mx
 import mlx.nn as nn
 from .base import BaseModelArgs
-from .layers import LayerNorm
 @dataclass
@@ -91,8 +90,8 @@ class TransformerBlock(nn.Module):
         self.self_attn = Attention(args)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
-        self.post_attention_layernorm = LayerNorm(
+        self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
+        self.post_attention_layernorm = nn.LayerNorm(
             args.hidden_size, eps=args.norm_epsilon
         )
         self.args = args
@@ -121,7 +120,7 @@ class Starcoder2Model(nn.Module):
         self.layers = [
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
+        self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
     def __call__(
         self,

View File

@ -1,4 +1,4 @@
mlx>=0.6 mlx>=0.8
numpy numpy
transformers>=4.38.0 transformers>=4.38.0
protobuf protobuf

View File

@@ -132,21 +132,6 @@ class MultiHeadAttention(nn.Module):
         return self.out_proj(values_hat), (keys, values)
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-    def __call__(self, x):
-        t = x.dtype
-        output = self._norm(x).astype(t)
-        return self.weight * output
 class DenseActivation(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
@@ -182,8 +167,8 @@ class TransformerEncoderLayer(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
         self.attention = MultiHeadAttention(config)
-        self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dense = DenseActivation(config)
     def __call__(self, x, mask):
@@ -202,7 +187,7 @@ class TransformerEncoder(nn.Module):
         self.layers = [
             TransformerEncoderLayer(config) for i in range(config.num_layers)
         ]
-        self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
     def __call__(self, x: mx.array):
@@ -217,9 +202,9 @@ class TransformerDecoderLayer(nn.Module):
         super().__init__()
         self.self_attention = MultiHeadAttention(config)
         self.cross_attention = MultiHeadAttention(config)
-        self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln3 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dense = DenseActivation(config)
     def __call__(
@@ -257,7 +242,7 @@ class TransformerDecoder(nn.Module):
         super().__init__()
         n_layers = getattr(config, "num_decoder_layers", config.num_layers)
         self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
-        self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
     def __call__(self, x, memory, cache=None):

View File

@@ -1,3 +1,3 @@
-mlx>=0.0.6
+mlx>=0.8.0
 transformers
 numpy

View File

@@ -21,7 +21,9 @@ class TestModels(unittest.TestCase):
         self.assertEqual(outputs.shape, (1, 2, vocab_size))
         self.assertEqual(outputs.dtype, t)
-        outputs, cache = model(mx.argmax(outputs[1, :], keepdims=True), cache=cache)
+        outputs, cache = model(
+            mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache
+        )
         self.assertEqual(outputs.shape, (1, 1, vocab_size))
         self.assertEqual(outputs.dtype, t)

View File

@@ -2,8 +2,12 @@
 This is an example of using MLX to fine-tune an LLM with low rank adaptation
 (LoRA) for a target task.[^lora] The example also supports quantized LoRA
-(QLoRA).[^qlora] The example works with Llama, Mistral, and Phi-2 style
-models available on Hugging Face.
+(QLoRA).[^qlora] The example works with Llama and Mistral style models
+available on Hugging Face.
+> [!TIP]
+> For a more fully featured LLM package, checkout [MLX
+> LM](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm).
 In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
 generate SQL queries from natural language. However, the example is intended to

View File

@@ -1,10 +1,11 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 import argparse
 import copy
 import mlx.core as mx
 import mlx.nn as nn
+import models
 import utils
 from mlx.utils import tree_flatten
@@ -12,11 +13,8 @@ from mlx.utils import tree_flatten
 def quantize(weights, config, args):
     quantized_config = copy.deepcopy(config)
-    # Get model classes
-    model_class, model_args_class = utils._get_classes(config=config)
     # Load the model:
-    model = model_class(model_args_class.from_dict(config))
+    model = models.Model(models.ModelArgs.from_dict(config))
     model.load_weights(list(weights.items()))
     # Quantize the model:

View File

@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 import argparse
 from pathlib import Path
@@ -7,7 +7,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import utils
 from mlx.utils import tree_flatten, tree_unflatten
-from models.lora import LoRALinear
+from models import LoRALinear
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")

View File

@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 import argparse
 import json
@@ -12,7 +12,7 @@ import mlx.optimizers as optim
 import numpy as np
 import utils as lora_utils
 from mlx.utils import tree_flatten, tree_unflatten
-from models.lora import LoRALinear
+from models import LoRALinear
 def build_parser():

View File

@@ -2,17 +2,13 @@
 import glob
 import inspect
-import json
 import math
 from dataclasses import dataclass
-from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from huggingface_hub import snapshot_download
-from transformers import AutoTokenizer
 @dataclass
@@ -134,20 +130,6 @@ class LoRALinear(nn.Module):
         return y + self.scale * z
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -192,13 +174,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-        if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -209,11 +184,10 @@ class Attention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)
@@ -235,8 +209,10 @@ class TransformerBlock(nn.Module):
         self.hidden_size = args.hidden_size
         self.self_attn = Attention(args)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
         self.args = args
     def __call__(
@@ -263,7 +239,7 @@ class LlamaModel(nn.Module):
         self.layers = [
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
     def __call__(
         self,
@@ -299,60 +275,3 @@ class Model(nn.Module):
     ):
         out, cache = self.model(inputs, cache)
         return self.lm_head(out), cache
-def load(path_or_hf_repo: str):
-    # If the path exists, it will try to load model form it
-    # otherwise download and cache from the hf_repo and cache
-    model_path = Path(path_or_hf_repo)
-    if not model_path.exists():
-        model_path = Path(
-            snapshot_download(
-                repo_id=path_or_hf_repo,
-                allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
-            )
-        )
-    with open(model_path / "config.json", "r") as f:
-        config = json.loads(f.read())
-        quantization = config.get("quantization", None)
-    model_args = ModelArgs.from_dict(config)
-    weight_files = glob.glob(str(model_path / "*.safetensors"))
-    if len(weight_files) == 0:
-        raise FileNotFoundError("No safetensors found in {}".format(model_path))
-    weights = {}
-    for wf in weight_files:
-        weights.update(mx.load(wf).items())
-    model = Model(model_args)
-    if quantization is not None:
-        nn.QuantizedLinear.quantize_module(
-            model,
-            **quantization,
-            linear_class_predicate=lambda m: isinstance(m, nn.Linear)
-            and m.weight.shape[0] != 8,
-        )
-    model.load_weights(list(weights.items()))
-    mx.eval(model.parameters())
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
-    return model, tokenizer, config
-def generate(prompt: mx.array, model: Model, temp: float = 0.0):
-    def sample(logits):
-        if temp == 0:
-            return mx.argmax(logits, axis=-1)
-        else:
-            return mx.random.categorical(logits * (1 / temp))
-    y = prompt
-    cache = None
-    while True:
-        logits, cache = model(y[None], cache=cache)
-        logits = logits[:, -1, :]
-        y = sample(logits)
-        yield y

View File

@@ -1,15 +0,0 @@
-import inspect
-from dataclasses import dataclass
-@dataclass
-class BaseModelArgs:
-    @classmethod
-    def from_dict(cls, params):
-        return cls(
-            **{
-                k: v
-                for k, v in params.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )

View File

@@ -1,202 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
-import mlx.core as mx
-import mlx.nn as nn
-from .base import BaseModelArgs
-@dataclass
-class ModelArgs(BaseModelArgs):
-    hidden_size: int
-    num_hidden_layers: int
-    intermediate_size: int
-    num_attention_heads: int
-    rms_norm_eps: float
-    vocab_size: int
-    num_key_value_heads: int = None
-    rope_theta: float = 10000
-    rope_traditional: bool = False
-    model_type: str = None
-    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-    def __post_init__(self):
-        if self.num_key_value_heads is None:
-            self.num_key_value_heads = self.num_attention_heads
-        if self.rope_scaling:
-            required_keys = {"factor", "type"}
-            if not all(key in self.rope_scaling for key in required_keys):
-                raise ValueError(f"rope_scaling must contain keys {required_keys}")
-            if self.rope_scaling["type"] != "linear":
-                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-class Attention(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        dim = args.hidden_size
-        self.n_heads = n_heads = args.num_attention_heads
-        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
-        self.repeats = n_heads // n_kv_heads
-        head_dim = args.hidden_size // n_heads
-        self.scale = head_dim**-0.5
-        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
-        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
-        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
-        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
-        rope_scale = (
-            1 / args.rope_scaling["factor"]
-            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
-            else 1
-        )
-        self.rope = nn.RoPE(
-            head_dim,
-            traditional=args.rope_traditional,
-            base=args.rope_theta,
-            scale=rope_scale,
-        )
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
-    ) -> mx.array:
-        B, L, D = x.shape
-        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-        # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            return a.reshape([B, self.n_heads, L, -1])
-        if self.repeats > 1:
-            keys, values = map(repeat, (keys, values))
-        if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output), (keys, values)
-class MLP(nn.Module):
-    def __init__(self, dim, hidden_dim):
-        super().__init__()
-        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
-        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
-        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
-    def __call__(self, x) -> mx.array:
-        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
-class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.num_attention_heads = args.num_attention_heads
-        self.hidden_size = args.hidden_size
-        self.self_attn = Attention(args)
-        self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.args = args
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
-    ) -> mx.array:
-        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
-        h = x + r
-        r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
-        return out, cache
-class LlamaModel(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.args = args
-        self.vocab_size = args.vocab_size
-        self.num_hidden_layers = args.num_hidden_layers
-        assert self.vocab_size > 0
-        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
-        self.layers = [
-            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
-        ]
-        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-    def __call__(
-        self,
-        inputs: mx.array,
-        cache=None,
-    ):
-        h = self.embed_tokens(inputs)
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
-        if cache is None:
-            cache = [None] * len(self.layers)
-        for e, layer in enumerate(self.layers):
-            h, cache[e] = layer(h, mask, cache[e])
-        return self.norm(h), cache
-class Model(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.model = LlamaModel(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
-    def __call__(
-        self,
-        inputs: mx.array,
-        cache=None,
-    ):
-        out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache

View File

@@ -1,86 +0,0 @@
-import math
-import mlx.core as mx
-import mlx.nn as nn
-class LoRALinear(nn.Module):
-    @staticmethod
-    def from_linear(linear: nn.Linear, rank: int = 8):
-        # TODO remove when input_dims and output_dims are attributes
-        # on linear and quantized linear
-        output_dims, input_dims = linear.weight.shape
-        if isinstance(linear, nn.QuantizedLinear):
-            input_dims *= 32 // linear.bits
-        lora_lin = LoRALinear(input_dims, output_dims, rank)
-        lora_lin.linear = linear
-        return lora_lin
-    def to_linear(self, de_quantize: bool = False):
-        linear = self.linear
-        bias = "bias" in linear
-        weight = linear.weight
-        is_quantized = isinstance(linear, nn.QuantizedLinear)
-        # Use the same type as the linear weight if not quantized
-        dtype = weight.dtype
-        if is_quantized:
-            dtype = mx.float16
-            weight = mx.dequantize(
-                weight,
-                linear.scales,
-                linear.biases,
-                linear.group_size,
-                linear.bits,
-            )
-        output_dims, input_dims = weight.shape
-        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
-        lora_b = (self.scale * self.lora_b.T).astype(dtype)
-        lora_a = self.lora_a.T.astype(dtype)
-        fused_linear.weight = weight + lora_b @ lora_a
-        if bias:
-            fused_linear.bias = linear.bias
-        if is_quantized and not de_quantize:
-            fused_linear = nn.QuantizedLinear.from_linear(
-                fused_linear,
-                linear.group_size,
-                linear.bits,
-            )
-        return fused_linear
-    def __init__(
-        self,
-        input_dims: int,
-        output_dims: int,
-        lora_rank: int = 8,
-        bias: bool = False,
-        scale: float = 20.0,
-    ):
-        super().__init__()
-        # Regular linear layer weights
-        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
-        # Scale for low-rank update
-        self.scale = scale
-        # Low rank lora weights
-        scale = 1 / math.sqrt(input_dims)
-        self.lora_a = mx.random.uniform(
-            low=-scale,
-            high=scale,
-            shape=(input_dims, lora_rank),
-        )
-        self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
-    def __call__(self, x):
-        dtype = self.linear.weight.dtype
-        if isinstance(self.linear, nn.QuantizedLinear):
-            dtype = self.linear.scales.dtype
-        y = self.linear(x.astype(dtype))
-        z = (x @ self.lora_a) @ self.lora_b
-        return y + self.scale * z

View File

@ -1,253 +0,0 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
vocab_size: int = 32000
max_position_embeddings: int = 4096 * 32
hidden_size: int = 4096
intermediate_size: int = 14336
num_hidden_layers: int = 32
num_attention_heads: int = 32
num_experts_per_tok: int = 2
num_key_value_heads: int = 8
num_local_experts: int = 8
rms_norm_eps: float = 1e-5
vocab_size: int
rope_theta: float = 1e6
rope_traditional: bool = False
model_type: str = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class MixtralAttention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.num_heads = args.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = args.num_key_value_heads
self.max_position_embeddings = args.max_position_embeddings
self.rope_theta = args.rope_theta
self.repeats = self.num_heads // self.num_key_value_heads
self.scale = self.head_dim**-0.5
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False
)
self.k_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.v_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
self.rope = nn.RoPE(
self.head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
0, 2, 1, 3
)
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
return a.reshape([B, self.num_heads, L, -1])
if self.repeats > 1:
keys, values = map(repeat, (keys, values))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
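# The explicit scores/softmax/matmul sequence above can be collapsed into mlx's
# fused attention kernel (assuming mlx >= 0.8 exposes it as
# mx.fast.scaled_dot_product_attention). A small self-contained check:
import mlx.core as mx

B, H, L, D = 1, 4, 8, 16
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
scale = D**-0.5
scores = mx.softmax((q * scale) @ k.transpose(0, 1, 3, 2), axis=-1)
manual = scores @ v
fused = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
assert mx.allclose(manual, fused, atol=1e-4)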
class MixtralBLockSparseTop2MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.ffn_dim = args.intermediate_size
self.hidden_dim = args.hidden_size
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.act_fn = nn.silu
def __call__(self, x: mx.array) -> mx.array:
current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
class MixtralSparseMoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_dim = args.hidden_size
self.ffn_dim = args.intermediate_size
self.num_experts = args.num_local_experts
self.num_experts_per_tok = args.num_experts_per_tok
# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
self.experts = [
MixtralBLockSparseTop2MLP(args=args) for _ in range(self.num_experts)
]
def __call__(self, x: mx.array) -> mx.array:
ne = self.num_experts_per_tok
orig_shape = x.shape
x = x.reshape(-1, x.shape[-1])
gates = self.gate(x)
# Route each token: take the indices of the `ne` largest gate logits
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne])
# Mixing weights come from a softmax over the selected experts only
scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
axis=-1,
).astype(gates.dtype)
# Materialize the assignments on the host so the loop below can gather each
# expert's tokens with numpy indexing
mx.eval(inds)
inds = np.array(inds)
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
# Run every expert only on the tokens routed to it, writing results into the
# corresponding top-k slot, then combine the slots with the mixing weights
for e, expert in enumerate(self.experts):
idx1, idx2 = map(mx.array, np.where(inds == e))
if idx1.size == 0:
continue
y[idx1, idx2] = expert(x[idx1])
y = (y * scores[:, :, None]).sum(axis=1)
return y.reshape(orig_shape)
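# A toy walk-through of the top-2 routing above: argpartition picks the indices
# of the two largest gate logits per token, and the softmax is taken over just
# that pair to produce the mixing weights (numbers are made up for illustration):
import mlx.core as mx

gates = mx.array([[0.1, 2.0, -1.0, 1.5]])   # one token, four experts
ne = 2
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
# inds selects experts 1 and 3 (in some order); scores are their weights and sum to 1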
class MixtralDecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = MixtralAttention(args)
self.block_sparse_moe = MixtralSparseMoeBlock(args)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.block_sparse_moe(self.post_attention_layernorm(h))
out = h + r
return out, cache
class MixtralModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = None
T = h.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.norm(h), cache
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model = MixtralModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
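# Interface sketch for the Model class above: __call__ returns (logits, cache)
# and the per-layer KV cache is fed back in for incremental decoding. The tiny
# ModelArgs below is a made-up configuration so the sketch stays cheap to run;
# it does not correspond to any released checkpoint.
import mlx.core as mx

args = ModelArgs(vocab_size=100, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
    num_local_experts=4)
model = Model(args)
tokens = mx.array([[1, 2, 3]])                # (batch, prompt_length)
logits, cache = model(tokens)                 # prefill builds the cache
next_token = mx.argmax(logits[:, -1, :], axis=-1)
logits, cache = model(next_token[:, None], cache=cache)   # one decode step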

View File

@ -1,138 +0,0 @@
import math
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
n_positions: int = 2048
vocab_size: int = 51200
n_embd: int = 2560
n_head: int = 32
n_layer: int = 32
rotary_dim: int = 32
class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array:
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class RoPEAttention(nn.Module):
def __init__(self, dims: int, n_head: int, rotary_dim: int):
super().__init__()
self.n_head = n_head
self.q_proj = nn.Linear(dims, dims)
self.k_proj = nn.Linear(dims, dims)
self.v_proj = nn.Linear(dims, dims)
self.dense = nn.Linear(dims, dims)
self.rope = nn.RoPE(rotary_dim, traditional=False)
def __call__(self, x, mask=None, cache=None):
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Extract some shapes
n_head = self.n_head
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries.astype(mx.float32)
keys = keys.astype(mx.float32)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.dense(values_hat), (keys, values)
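# The offset passed to rope() above is what keeps positions consistent during
# cached decoding: a token processed alone with offset=N receives the same
# rotation it would get at position N of a full pass. A small check (shapes are
# arbitrary; 32 matches the default rotary_dim above):
import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(32, traditional=False)
q = mx.random.normal((1, 4, 8, 32))           # (batch, heads, positions, dims)
full = rope(q)[:, :, 7:8, :]                  # position 7 from a full pass
step = rope(q[:, :, 7:8, :], offset=7)        # the same token fed incrementally
assert mx.allclose(full, step, atol=1e-5)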
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.fc1 = nn.Linear(dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, dim)
self.act = nn.GELU(approx="precise")
def __call__(self, x) -> mx.array:
return self.fc2(self.act(self.fc1(x)))
class ParallelBlock(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
dims = config.n_embd
mlp_dims = dims * 4
self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim)
self.input_layernorm = LayerNorm(dims)
self.mlp = MLP(dims, mlp_dims)
def __call__(self, x, mask, cache):
h = self.input_layernorm(x)
attn_h, cache = self.self_attn(h, mask, cache)
ff_h = self.mlp(h)
# Parallel residual: the attention and MLP branches both read the same
# normalized input and are summed with the skip connection
return attn_h + ff_h + x, cache
class Transformer(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
self.layers = [ParallelBlock(config) for i in range(config.n_layer)]
self.final_layernorm = LayerNorm(config.n_embd)
def __call__(self, x, mask, cache):
x = self.embed_tokens(x)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
x, cache[e] = layer(x, mask, cache[e])
return self.final_layernorm(x), cache
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model = Transformer(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> tuple[mx.array, mx.array]:
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
y, cache = self.model(x, mask, cache)
return self.lm_head(y), cache
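# Both this model and the Mixtral model above build their masks with
# nn.MultiHeadAttention.create_additive_causal_mask(N), which returns an (N, N)
# additive mask: zeros on and below the diagonal and large negative values above
# it, so softmax gives future positions effectively zero weight.
import mlx.core as mx
import mlx.nn as nn

mask = nn.MultiHeadAttention.create_additive_causal_mask(4)
print(mask.shape)   # (4, 4)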

View File

@ -1,3 +1,3 @@
mlx>=0.0.7 mlx>=0.8.0
transformers transformers
numpy numpy

View File

@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023-2024 Apple Inc.
import glob import glob
import json import json
@ -8,40 +8,10 @@ from typing import Generator
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import models.llama as llama import models
import models.mixtral as mixtral
import models.phi2 as phi2
import transformers import transformers
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
# Constants
MODEL_MAPPING = {
"llama": llama,
"mistral": llama, # mistral is compatible with llama
"phi": phi2,
"mixtral": mixtral,
}
def _get_classes(config: dict):
"""
Retrieve the model and model args classes based on the configuration.
Args:
config (dict): The model configuration.
Returns:
A tuple containing the Model class and the ModelArgs class.
"""
model_type = config["model_type"]
if model_type not in MODEL_MAPPING:
msg = f"Model type {model_type} not supported."
logging.error(msg)
raise ValueError(msg)
arch = MODEL_MAPPING[model_type]
return arch.Model, arch.ModelArgs
def fetch_from_hub(hf_path: str): def fetch_from_hub(hf_path: str):
model_path = snapshot_download( model_path = snapshot_download(
@ -157,9 +127,8 @@ def load(path_or_hf_repo: str):
for wf in weight_files: for wf in weight_files:
weights.update(mx.load(wf).items()) weights.update(mx.load(wf).items())
model_class, model_args_class = _get_classes(config=config) model_args = models.ModelArgs.from_dict(config)
model_args = model_args_class.from_dict(config) model = models.Model(model_args)
model = model_class(model_args)
if quantization is not None: if quantization is not None:
nn.QuantizedLinear.quantize_module( nn.QuantizedLinear.quantize_module(
model, model,

View File

@ -1,3 +1,3 @@
mlx mlx>=0.8.0
numpy numpy
transformers transformers

View File

@ -134,21 +134,6 @@ class MultiHeadAttention(nn.Module):
return self.out_proj(values_hat), (keys, values) return self.out_proj(values_hat), (keys, values)
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
t = x.dtype
output = self._norm(x).astype(t)
return self.weight * output
class DenseActivation(nn.Module): class DenseActivation(nn.Module):
def __init__(self, config: T5Config): def __init__(self, config: T5Config):
super().__init__() super().__init__()
@ -184,8 +169,8 @@ class TransformerEncoderLayer(nn.Module):
def __init__(self, config: T5Config): def __init__(self, config: T5Config):
super().__init__() super().__init__()
self.attention = MultiHeadAttention(config) self.attention = MultiHeadAttention(config)
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config) self.dense = DenseActivation(config)
def __call__(self, x, mask): def __call__(self, x, mask):
@ -204,7 +189,7 @@ class TransformerEncoder(nn.Module):
self.layers = [ self.layers = [
TransformerEncoderLayer(config) for i in range(config.num_layers) TransformerEncoderLayer(config) for i in range(config.num_layers)
] ]
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True) self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
def __call__(self, x: mx.array): def __call__(self, x: mx.array):
@ -219,9 +204,9 @@ class TransformerDecoderLayer(nn.Module):
super().__init__() super().__init__()
self.self_attention = MultiHeadAttention(config) self.self_attention = MultiHeadAttention(config)
self.cross_attention = MultiHeadAttention(config) self.cross_attention = MultiHeadAttention(config)
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln3 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config) self.dense = DenseActivation(config)
def __call__( def __call__(
@ -252,7 +237,7 @@ class TransformerDecoder(nn.Module):
super().__init__() super().__init__()
n_layers = getattr(config, "num_decoder_layers", config.num_layers) n_layers = getattr(config, "num_decoder_layers", config.num_layers)
self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)] self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=False) self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
def __call__(self, x, memory, mask, memory_mask, cache=None): def __call__(self, x, memory, mask, memory_mask, cache=None):

View File

@ -1,4 +1,4 @@
mlx>=0.1 mlx>=0.8
numba numba
numpy numpy
torch torch

View File

@ -37,11 +37,6 @@ def sinusoids(length, channels, max_timescale=10000):
return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1) return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array:
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int): def __init__(self, n_state: int, n_head: int):
super().__init__() super().__init__()
@ -98,17 +93,17 @@ class ResidualAttentionBlock(nn.Module):
super().__init__() super().__init__()
self.attn = MultiHeadAttention(n_state, n_head) self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state) self.attn_ln = nn.LayerNorm(n_state)
self.cross_attn = ( self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None MultiHeadAttention(n_state, n_head) if cross_attention else None
) )
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4 n_mlp = n_state * 4
self.mlp1 = nn.Linear(n_state, n_mlp) self.mlp1 = nn.Linear(n_state, n_mlp)
self.mlp2 = nn.Linear(n_mlp, n_state) self.mlp2 = nn.Linear(n_mlp, n_state)
self.mlp_ln = LayerNorm(n_state) self.mlp_ln = nn.LayerNorm(n_state)
def __call__(self, x, xa=None, mask=None, kv_cache=None): def __call__(self, x, xa=None, mask=None, kv_cache=None):
kv, cross_kv = kv_cache if kv_cache else (None, None) kv, cross_kv = kv_cache if kv_cache else (None, None)
@ -140,7 +135,7 @@ class AudioEncoder(nn.Module):
self._positional_embedding = sinusoids(n_ctx, n_state).astype(dtype) self._positional_embedding = sinusoids(n_ctx, n_state).astype(dtype)
self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
self.ln_post = LayerNorm(n_state) self.ln_post = nn.LayerNorm(n_state)
def __call__(self, x): def __call__(self, x):
x = nn.gelu(self.conv1(x)).astype(x.dtype) x = nn.gelu(self.conv1(x)).astype(x.dtype)
@ -174,7 +169,7 @@ class TextDecoder(nn.Module):
ResidualAttentionBlock(n_state, n_head, cross_attention=True) ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer) for _ in range(n_layer)
] ]
self.ln = LayerNorm(n_state) self.ln = nn.LayerNorm(n_state)
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype( self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(
dtype dtype
) )