diff --git a/llava/language.py b/llava/language.py index e9023b99..7c96c676 100644 --- a/llava/language.py +++ b/llava/language.py @@ -45,20 +45,6 @@ class TextConfig: raise ValueError("rope_scaling 'type' currently only supports 'linear'") -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def _norm(self, x): - return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) - - def __call__(self, x): - output = self._norm(x.astype(mx.float32)).astype(x.dtype) - return self.weight * output - - class Attention(nn.Module): def __init__(self, config: TextConfig): super().__init__() @@ -105,10 +91,6 @@ class Attention(nn.Module): keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - if self.repeats > 1: - keys = mx.repeat(keys, self.repeats, axis=1) - values = mx.repeat(values, self.repeats, axis=1) - if cache is not None: key_cache, value_cache = cache queries = self.rope(queries, offset=key_cache.shape[2]) @@ -119,11 +101,10 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output), (keys, values) @@ -145,8 +126,8 @@ class TransformerBlock(nn.Module): self.hidden_size = config.hidden_size self.self_attn = Attention(config) self.mlp = MLP(config.hidden_size, config.intermediate_size) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm( + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.config = config @@ -175,7 +156,7 @@ class Llama(nn.Module): self.layers = [ TransformerBlock(config=config) for _ in range(config.num_hidden_layers) ] - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def __call__( self, diff --git a/llava/requirements.txt b/llava/requirements.txt index a11d9148..4d904e4e 100644 --- a/llava/requirements.txt +++ b/llava/requirements.txt @@ -1,4 +1,4 @@ -mlx>=0.5.0 +mlx>=0.8.0 numpy transformers torch diff --git a/llms/gguf_llm/models.py b/llms/gguf_llm/models.py index e60b60d5..2a1f3435 100644 --- a/llms/gguf_llm/models.py +++ b/llms/gguf_llm/models.py @@ -49,20 +49,6 @@ class ModelArgs: ) -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def _norm(self, x): - return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) - - def __call__(self, x): - output = self._norm(x.astype(mx.float32)).astype(x.dtype) - return self.weight * output - - class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -107,10 +93,6 @@ class Attention(nn.Module): keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - if self.repeats > 1: - keys = 
mx.repeat(keys, self.repeats, axis=1) - values = mx.repeat(values, self.repeats, axis=1) - if cache is not None: key_cache, value_cache = cache queries = self.rope(queries, offset=key_cache.shape[2]) @@ -121,11 +103,10 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output), (keys, values) @@ -147,8 +128,10 @@ class TransformerBlock(nn.Module): self.hidden_size = args.hidden_size self.self_attn = Attention(args) self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) self.args = args def __call__( @@ -175,7 +158,7 @@ class LlamaModel(nn.Module): self.layers = [ TransformerBlock(args=args) for _ in range(args.num_hidden_layers) ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) def __call__( self, diff --git a/llms/gguf_llm/requirements.txt b/llms/gguf_llm/requirements.txt index f5921e36..656fe028 100644 --- a/llms/gguf_llm/requirements.txt +++ b/llms/gguf_llm/requirements.txt @@ -1,4 +1,4 @@ -mlx>=0.0.11 +mlx>=0.8 numpy protobuf==3.20.2 sentencepiece diff --git a/llms/llama/llama.py b/llms/llama/llama.py index fb12282b..3e5f78a1 100644 --- a/llms/llama/llama.py +++ b/llms/llama/llama.py @@ -28,20 +28,6 @@ class ModelArgs: rope_traditional: bool = True -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def _norm(self, x): - return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) - - def __call__(self, x): - output = self._norm(x.astype(mx.float32)).astype(x.dtype) - return self.weight * output - - class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -120,8 +106,8 @@ class TransformerBlock(nn.Module): self.dim = args.dim self.attention = Attention(args) self.feed_forward = FeedForward(args=args) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps) self.args = args def __call__( @@ -144,7 +130,7 @@ class Llama(nn.Module): self.vocab_size = args.vocab_size self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] - self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps) self.output = nn.Linear(args.dim, args.vocab_size, bias=False) def __call__(self, x): diff --git a/llms/llama/requirements.txt b/llms/llama/requirements.txt index 755af473..6b458abc 100644 --- a/llms/llama/requirements.txt +++ b/llms/llama/requirements.txt @@ -1,4 +1,4 @@ -mlx>=0.0.6 +mlx>=0.8.0 sentencepiece torch 
numpy diff --git a/llms/mistral/mistral.py b/llms/mistral/mistral.py index 39456d97..24ae730d 100644 --- a/llms/mistral/mistral.py +++ b/llms/mistral/mistral.py @@ -26,20 +26,6 @@ class ModelArgs: rope_theta: float = 10000 -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def _norm(self, x): - return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) - - def __call__(self, x): - output = self._norm(x.astype(mx.float32)).astype(x.dtype) - return self.weight * output - - class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -73,9 +59,6 @@ class Attention(nn.Module): keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - keys = mx.repeat(keys, self.repeats, axis=1) - values = mx.repeat(values, self.repeats, axis=1) - if cache is not None: key_cache, value_cache = cache queries = self.rope(queries, offset=key_cache.shape[2]) @@ -86,11 +69,10 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.wo(output), (keys, values) @@ -113,8 +95,8 @@ class TransformerBlock(nn.Module): self.dim = args.dim self.attention = Attention(args) self.feed_forward = FeedForward(args=args) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps) self.args = args def __call__( @@ -139,7 +121,7 @@ class Mistral(nn.Module): assert self.vocab_size > 0 self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] - self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps) self.output = nn.Linear(args.dim, args.vocab_size, bias=False) def __call__( diff --git a/llms/mistral/requirements.txt b/llms/mistral/requirements.txt index 755af473..6b458abc 100644 --- a/llms/mistral/requirements.txt +++ b/llms/mistral/requirements.txt @@ -1,4 +1,4 @@ -mlx>=0.0.6 +mlx>=0.8.0 sentencepiece torch numpy diff --git a/llms/mixtral/mixtral.py b/llms/mixtral/mixtral.py index bb7b5238..d5e07926 100644 --- a/llms/mixtral/mixtral.py +++ b/llms/mixtral/mixtral.py @@ -26,20 +26,6 @@ class ModelArgs: moe: dict = None -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def _norm(self, x): - return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) - - def __call__(self, x): - output = self._norm(x.astype(mx.float32)).astype(x.dtype) - return self.weight * output - - class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -73,9 +59,6 @@ class Attention(nn.Module): keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - keys = mx.repeat(keys, 
self.repeats, axis=1) - values = mx.repeat(values, self.repeats, axis=1) - if cache is not None: key_cache, value_cache = cache queries = self.rope(queries, offset=key_cache.shape[2]) @@ -86,11 +69,10 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.wo(output), (keys, values) @@ -144,8 +126,8 @@ class MOETransformerBlock(nn.Module): self.dim = args.dim self.attention = Attention(args) self.feed_forward = MOEFeedForward(args=args) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps) self.args = args def __call__( @@ -170,7 +152,7 @@ class Mixtral(nn.Module): assert self.vocab_size > 0 self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) self.layers = [MOETransformerBlock(args=args) for _ in range(args.n_layers)] - self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps) self.output = nn.Linear(args.dim, args.vocab_size, bias=False) def __call__( diff --git a/llms/mixtral/requirements.txt b/llms/mixtral/requirements.txt index d775b88f..6b458abc 100644 --- a/llms/mixtral/requirements.txt +++ b/llms/mixtral/requirements.txt @@ -1,4 +1,4 @@ -mlx +mlx>=0.8.0 sentencepiece torch numpy diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py index 724d2007..f8df07f2 100644 --- a/llms/mlx_lm/models/cohere.py +++ b/llms/mlx_lm/models/cohere.py @@ -5,7 +5,6 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs -from .layers import LayerNorm @dataclass @@ -97,7 +96,7 @@ class TransformerBlock(nn.Module): self.self_attn = Attention(args) self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = LayerNorm( + self.input_layernorm = nn.LayerNorm( args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias ) self.args = args @@ -125,7 +124,7 @@ class CohereModel(nn.Module): self.layers = [ TransformerBlock(args=args) for _ in range(args.num_hidden_layers) ] - self.norm = LayerNorm( + self.norm = nn.LayerNorm( args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias ) diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py index 6b35f257..0ab99e58 100644 --- a/llms/mlx_lm/models/gemma.py +++ b/llms/mlx_lm/models/gemma.py @@ -23,13 +23,6 @@ class ModelArgs(BaseModelArgs): rope_traditional: bool = False -@partial(mx.compile, shapeless=True) -def rms_norm(x, weight, eps): - x = x.astype(mx.float32) - x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps) - return (1.0 + weight) * x.astype(weight.dtype) - - class RMSNorm(nn.Module): def __init__(self, dims: int, eps: float = 1e-5): super().__init__() @@ -37,7 +30,7 @@ class RMSNorm(nn.Module): self.eps = eps def __call__(self, x): - return rms_norm(x, self.weight, self.eps) + return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps) class Attention(nn.Module): diff --git a/llms/mlx_lm/models/layers.py b/llms/mlx_lm/models/layers.py deleted file mode 100644 index 
cf91ad19..00000000 --- a/llms/mlx_lm/models/layers.py +++ /dev/null @@ -1,80 +0,0 @@ -from functools import partial - -import mlx.core as mx -import mlx.nn as nn - - -@partial(mx.compile, shapeless=True) -def rms_norm(x, weight, eps): - x = x.astype(mx.float32) - x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps) - return weight * x.astype(weight.dtype) - - -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def __call__(self, x): - return rms_norm(x, self.weight, self.eps) - - -@partial(mx.compile, shapeless=True) -def ln_norm(x, eps, weight=None, bias=None): - """ - Layer normalization for input tensor x. - - Args: - x (np.ndarray): Input tensor. - eps (float, optional): Small value to avoid division by zero. - weight (np.ndarray, optional): Weight tensor for normalization. - bias (np.ndarray, optional): Bias tensor for normalization. - - Returns: - np.ndarray: Normalized tensor. - """ - t = x.dtype - x = x.astype(mx.float32) - - # Compute mean and variance along the last dimension - means = mx.mean(x, axis=-1, keepdims=True) - var = mx.var(x, axis=-1, keepdims=True) - - # Normalize the input tensor - x = (x - means) * mx.rsqrt(var + eps) - x = x.astype(t) - - # Apply weight and bias if provided - if weight is not None: - x = x * weight - if bias is not None: - x = x + bias - return x - - -class LayerNorm(nn.Module): - def __init__( - self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True - ): - super().__init__() - self.eps = eps - self.dims = dims - self.affine = affine - - if affine: - self.weight = mx.ones((dims,)) - self.bias = mx.zeros((dims,)) if bias else None - - def _extra_repr(self): - return f"{self.dims}, eps={self.eps}, affine={'weight' in self}" - - def __call__(self, x: mx.array) -> mx.array: - if self.affine: - if self.bias is not None: - return ln_norm(x, self.eps, self.weight, self.bias) - else: - return ln_norm(x, self.eps, self.weight) - else: - return ln_norm(x, self.eps) diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 972f381e..b74dbe3b 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -5,7 +5,6 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs -from .layers import RMSNorm @dataclass @@ -113,8 +112,10 @@ class TransformerBlock(nn.Module): self.hidden_size = args.hidden_size self.self_attn = Attention(args) self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) self.args = args def __call__( @@ -141,7 +142,7 @@ class LlamaModel(nn.Module): self.layers = [ TransformerBlock(args=args) for _ in range(args.num_hidden_layers) ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) def __call__( self, diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index c2c9cb94..26a56779 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -6,7 +6,6 @@ import mlx.nn as nn import numpy as np from .base import BaseModelArgs -from .layers import RMSNorm @dataclass @@ -146,7 +145,7 @@ class MixtralSparseMoeBlock(nn.Module): if 
self.training: mx.eval(inds) inds = np.array(inds) - y = mx.zeros((x.shape[0], ne, x.shape[-1])) + y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype) for e, expert in enumerate(self.experts): idx1, idx2 = map(mx.array, np.where(inds == e)) if idx1.size == 0: @@ -173,8 +172,10 @@ class MixtralDecoderLayer(nn.Module): self.self_attn = MixtralAttention(args) self.block_sparse_moe = MixtralSparseMoeBlock(args) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) def __call__( self, @@ -199,7 +200,7 @@ class MixtralModel(nn.Module): self.layers = [ MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers) ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) def __call__( self, diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py index 548457d6..b2ceec37 100644 --- a/llms/mlx_lm/models/olmo.py +++ b/llms/mlx_lm/models/olmo.py @@ -6,7 +6,6 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs -from .layers import LayerNorm try: import hf_olmo @@ -46,8 +45,8 @@ class TransformerBlock(nn.Module): self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False) self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False) - self.att_norm = LayerNorm(dim, affine=False) - self.ff_norm = LayerNorm(dim, affine=False) + self.att_norm = nn.LayerNorm(dim, affine=False) + self.ff_norm = nn.LayerNorm(dim, affine=False) head_dim = dim // self.n_heads self.scale = head_dim**-0.5 @@ -120,7 +119,7 @@ class Transformer(nn.Module): self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)] if not self.weight_tying: self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False) - self.norm = LayerNorm(args.d_model, affine=False) + self.norm = nn.LayerNorm(args.d_model, affine=False) def __call__( self, diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index 3d5a659e..91b97023 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -6,7 +6,6 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs -from .layers import LayerNorm @dataclass @@ -122,7 +121,9 @@ class PhiDecoderLayer(nn.Module): def __init__(self, config: ModelArgs): super().__init__() self.self_attn = PhiAttention(config=config) - self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.input_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) self.mlp = PhiMLP(config) def __call__(self, x, mask, cache): @@ -137,7 +138,9 @@ class PhiModel(nn.Module): super().__init__() self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)] - self.final_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.final_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) def __call__(self, x, cache): x = self.embed_tokens(x) diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index 0f2c8369..3849a70b 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -7,8 +7,6 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .layers import 
LayerNorm - @dataclass class ModelArgs: @@ -116,7 +114,7 @@ class MOE(nn.Module): if self.training: ys = [] - y = mx.zeros((x.shape[0], ne, x.shape[-1])) + y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype) for e, expert in enumerate(self.mlp): idx1, idx2 = map(mx.array, np.where(inds == e)) if idx1.size == 0: @@ -141,7 +139,7 @@ class ParallelBlock(nn.Module): dims = config.model_dim mlp_dims = dims * 4 self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim) - self.ln = LayerNorm(dims) + self.ln = nn.LayerNorm(dims) self.moe = MOE(config, dims, mlp_dims) def __call__(self, x, mask, cache): @@ -179,7 +177,7 @@ class Embd(nn.Module): class OutputHead(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() - self.ln = LayerNorm(config.model_dim) + self.ln = nn.LayerNorm(config.model_dim) self.linear = nn.Linear(config.model_dim, config.num_vocab) def __call__(self, inputs): diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index c0a32648..c4a87a1e 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -6,7 +6,6 @@ import mlx.nn as nn import numpy as np from .base import BaseModelArgs -from .layers import RMSNorm @dataclass @@ -82,9 +81,6 @@ class Attention(nn.Module): # expand shared kv assert self.k_num_heads == self.v_num_heads - repeats = self.config.n_shared_head - key_states = mx.repeat(key_states, repeats, axis=1) - value_states = mx.repeat(value_states, repeats, axis=1) kv_seq_len = 0 if cache is not None: @@ -97,12 +93,14 @@ class Attention(nn.Module): key_states = mx.concatenate([cache[0], key_states], axis=2) value_states = mx.concatenate([cache[1], value_states], axis=2) - scores = (query_states * self.scale) @ key_states.transpose(0, 1, 3, 2) - if attention_mask is not None: - scores += attention_mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ value_states).transpose(0, 2, 1, 3).reshape(bsz, q_len, -1) - + output = mx.fast.scaled_dot_product_attention( + query_states, + key_states, + value_states, + scale=self.scale, + mask=attention_mask, + ) + output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1) return self.o_proj(output), (key_states, value_states) @@ -127,7 +125,7 @@ class PlamoDecoderLayer(nn.Module): self.hidden_size = config.hidden_size self.self_attn = Attention(config) self.mlp = MLP(config) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def __call__( self, @@ -170,7 +168,7 @@ class PlamoModel(nn.Module): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = PlamoDecoder(config) # type: ignore - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def __call__( self, diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index 5fe02e98..a4e82dd2 100644 --- a/llms/mlx_lm/models/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -5,7 +5,6 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs -from .layers import RMSNorm @dataclass @@ -102,9 +101,9 @@ class TransformerBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) self.attn = Attention(args) - self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + self.ln_2 = 
nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) self.mlp = MLP(args) def __call__(self, x, mask=None, cache=None): @@ -124,7 +123,7 @@ class QwenModel(nn.Module): super().__init__() self.wte = nn.Embedding(args.vocab_size, args.hidden_size) self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)] - self.ln_f = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) def __call__(self, inputs, mask=None, cache=None): x = self.wte(inputs) diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index 1f196b0f..1e694b20 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -5,7 +5,6 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs -from .layers import RMSNorm @dataclass @@ -114,8 +113,10 @@ class TransformerBlock(nn.Module): self.hidden_size = args.hidden_size self.self_attn = Attention(args) self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) self.args = args def __call__( @@ -142,7 +143,7 @@ class Qwen2Model(nn.Module): self.layers = [ TransformerBlock(args=args) for _ in range(args.num_hidden_layers) ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) def __call__( self, diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index 73b01961..f685c76d 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -6,7 +6,6 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs -from .layers import LayerNorm @dataclass @@ -120,8 +119,10 @@ class DecoderLayer(nn.Module): super().__init__() self.self_attn = Attention(config=config) self.mlp = MLP(config.hidden_size, config.intermediate_size) - self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.post_attention_layernorm = LayerNorm( + self.input_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.post_attention_layernorm = nn.LayerNorm( config.hidden_size, eps=config.layer_norm_eps ) @@ -138,7 +139,7 @@ class StableLM(nn.Module): super().__init__() self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)] - self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def __call__(self, x, mask, cache): x = self.embed_tokens(x) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 6b10f716..f18160a5 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -5,7 +5,6 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs -from .layers import LayerNorm @dataclass @@ -91,8 +90,8 @@ class TransformerBlock(nn.Module): self.self_attn = Attention(args) self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = LayerNorm(args.hidden_size, eps=args.norm_epsilon) - self.post_attention_layernorm = LayerNorm( + self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon) + 
self.post_attention_layernorm = nn.LayerNorm( args.hidden_size, eps=args.norm_epsilon ) self.args = args @@ -121,7 +120,7 @@ class Starcoder2Model(nn.Module): self.layers = [ TransformerBlock(args=args) for _ in range(args.num_hidden_layers) ] - self.norm = LayerNorm(args.hidden_size, eps=args.norm_epsilon) + self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon) def __call__( self, diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt index 040fa864..80421d75 100644 --- a/llms/mlx_lm/requirements.txt +++ b/llms/mlx_lm/requirements.txt @@ -1,4 +1,4 @@ -mlx>=0.6 +mlx>=0.8 numpy transformers>=4.38.0 protobuf diff --git a/llms/speculative_decoding/model.py b/llms/speculative_decoding/model.py index ed4a7d77..5ce5c300 100644 --- a/llms/speculative_decoding/model.py +++ b/llms/speculative_decoding/model.py @@ -132,21 +132,6 @@ class MultiHeadAttention(nn.Module): return self.out_proj(values_hat), (keys, values) -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def _norm(self, x): - return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) - - def __call__(self, x): - t = x.dtype - output = self._norm(x).astype(t) - return self.weight * output - - class DenseActivation(nn.Module): def __init__(self, config: T5Config): super().__init__() @@ -182,8 +167,8 @@ class TransformerEncoderLayer(nn.Module): def __init__(self, config: T5Config): super().__init__() self.attention = MultiHeadAttention(config) - self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) - self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.dense = DenseActivation(config) def __call__(self, x, mask): @@ -202,7 +187,7 @@ class TransformerEncoder(nn.Module): self.layers = [ TransformerEncoderLayer(config) for i in range(config.num_layers) ] - self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.relative_attention_bias = RelativePositionBias(config, bidirectional=True) def __call__(self, x: mx.array): @@ -217,9 +202,9 @@ class TransformerDecoderLayer(nn.Module): super().__init__() self.self_attention = MultiHeadAttention(config) self.cross_attention = MultiHeadAttention(config) - self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) - self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) - self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln3 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.dense = DenseActivation(config) def __call__( @@ -257,7 +242,7 @@ class TransformerDecoder(nn.Module): super().__init__() n_layers = getattr(config, "num_decoder_layers", config.num_layers) self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)] - self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.relative_attention_bias = RelativePositionBias(config, bidirectional=False) def __call__(self, x, memory, cache=None): diff --git a/llms/speculative_decoding/requirements.txt b/llms/speculative_decoding/requirements.txt index 501c713c..12199383 100644 --- 
a/llms/speculative_decoding/requirements.txt +++ b/llms/speculative_decoding/requirements.txt @@ -1,3 +1,3 @@ -mlx>=0.0.6 +mlx>=0.8.0 transformers numpy diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index fbe1bfeb..865d419d 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -21,7 +21,9 @@ class TestModels(unittest.TestCase): self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) - outputs, cache = model(mx.argmax(outputs[1, :], keepdims=True), cache=cache) + outputs, cache = model( + mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache + ) self.assertEqual(outputs.shape, (1, 1, vocab_size)) self.assertEqual(outputs.dtype, t) diff --git a/lora/README.md b/lora/README.md index 9c379679..5c4b5ee8 100644 --- a/lora/README.md +++ b/lora/README.md @@ -2,8 +2,12 @@ This is an example of using MLX to fine-tune an LLM with low rank adaptation (LoRA) for a target task.[^lora] The example also supports quantized LoRA -(QLoRA).[^qlora] The example works with Llama, Mistral, and Phi-2 style -models available on Hugging Face. +(QLoRA).[^qlora] The example works with Llama and Mistral style models +available on Hugging Face. + +> [!TIP] +> For a more fully featured LLM package, checkout [MLX +> LM](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm). In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to generate SQL queries from natural language. However, the example is intended to diff --git a/lora/convert.py b/lora/convert.py index bc85eb5e..26928c96 100644 --- a/lora/convert.py +++ b/lora/convert.py @@ -1,10 +1,11 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import argparse import copy import mlx.core as mx import mlx.nn as nn +import models import utils from mlx.utils import tree_flatten @@ -12,11 +13,8 @@ from mlx.utils import tree_flatten def quantize(weights, config, args): quantized_config = copy.deepcopy(config) - # Get model classes - model_class, model_args_class = utils._get_classes(config=config) - # Load the model: - model = model_class(model_args_class.from_dict(config)) + model = models.Model(models.ModelArgs.from_dict(config)) model.load_weights(list(weights.items())) # Quantize the model: diff --git a/lora/fuse.py b/lora/fuse.py index a957ff28..6244ecd1 100644 --- a/lora/fuse.py +++ b/lora/fuse.py @@ -1,4 +1,4 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import argparse from pathlib import Path @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import utils from mlx.utils import tree_flatten, tree_unflatten -from models.lora import LoRALinear +from models import LoRALinear if __name__ == "__main__": parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") diff --git a/lora/lora.py b/lora/lora.py index f21f0d2d..116c0a94 100644 --- a/lora/lora.py +++ b/lora/lora.py @@ -1,4 +1,4 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. 
import argparse import json @@ -12,7 +12,7 @@ import mlx.optimizers as optim import numpy as np import utils as lora_utils from mlx.utils import tree_flatten, tree_unflatten -from models.lora import LoRALinear +from models import LoRALinear def build_parser(): diff --git a/lora/models.py b/lora/models.py index 293b4f96..587cf3f7 100644 --- a/lora/models.py +++ b/lora/models.py @@ -2,17 +2,13 @@ import glob import inspect -import json import math from dataclasses import dataclass -from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn import numpy as np -from huggingface_hub import snapshot_download -from transformers import AutoTokenizer @dataclass @@ -134,20 +130,6 @@ class LoRALinear(nn.Module): return y + self.scale * z -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def _norm(self, x): - return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) - - def __call__(self, x): - output = self._norm(x.astype(mx.float32)).astype(x.dtype) - return self.weight * output - - class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -192,13 +174,6 @@ class Attention(nn.Module): keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - return a.reshape([B, self.n_heads, L, -1]) - - if self.repeats > 1: - keys, values = map(repeat, (keys, values)) - if cache is not None: key_cache, value_cache = cache queries = self.rope(queries, offset=key_cache.shape[2]) @@ -209,11 +184,10 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output), (keys, values) @@ -235,8 +209,10 @@ class TransformerBlock(nn.Module): self.hidden_size = args.hidden_size self.self_attn = Attention(args) self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) self.args = args def __call__( @@ -263,7 +239,7 @@ class LlamaModel(nn.Module): self.layers = [ TransformerBlock(args=args) for _ in range(args.num_hidden_layers) ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) def __call__( self, @@ -299,60 +275,3 @@ class Model(nn.Module): ): out, cache = self.model(inputs, cache) return self.lm_head(out), cache - - -def load(path_or_hf_repo: str): - # If the path exists, it will try to load model form it - # otherwise download and cache from the hf_repo and cache - model_path = Path(path_or_hf_repo) - if not model_path.exists(): - model_path = Path( - 
snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], - ) - ) - - with open(model_path / "config.json", "r") as f: - config = json.loads(f.read()) - quantization = config.get("quantization", None) - model_args = ModelArgs.from_dict(config) - - weight_files = glob.glob(str(model_path / "*.safetensors")) - if len(weight_files) == 0: - raise FileNotFoundError("No safetensors found in {}".format(model_path)) - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf).items()) - - model = Model(model_args) - if quantization is not None: - nn.QuantizedLinear.quantize_module( - model, - **quantization, - linear_class_predicate=lambda m: isinstance(m, nn.Linear) - and m.weight.shape[0] != 8, - ) - - model.load_weights(list(weights.items())) - - mx.eval(model.parameters()) - tokenizer = AutoTokenizer.from_pretrained(model_path) - return model, tokenizer, config - - -def generate(prompt: mx.array, model: Model, temp: float = 0.0): - def sample(logits): - if temp == 0: - return mx.argmax(logits, axis=-1) - else: - return mx.random.categorical(logits * (1 / temp)) - - y = prompt - cache = None - while True: - logits, cache = model(y[None], cache=cache) - logits = logits[:, -1, :] - y = sample(logits) - yield y diff --git a/lora/models/__init__.py b/lora/models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/lora/models/base.py b/lora/models/base.py deleted file mode 100644 index d1ea0b2c..00000000 --- a/lora/models/base.py +++ /dev/null @@ -1,15 +0,0 @@ -import inspect -from dataclasses import dataclass - - -@dataclass -class BaseModelArgs: - @classmethod - def from_dict(cls, params): - return cls( - **{ - k: v - for k, v in params.items() - if k in inspect.signature(cls).parameters - } - ) diff --git a/lora/models/llama.py b/lora/models/llama.py deleted file mode 100644 index ee026363..00000000 --- a/lora/models/llama.py +++ /dev/null @@ -1,202 +0,0 @@ -from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs - - -@dataclass -class ModelArgs(BaseModelArgs): - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - num_key_value_heads: int = None - rope_theta: float = 10000 - rope_traditional: bool = False - model_type: str = None - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - if self.rope_scaling: - required_keys = {"factor", "type"} - if not all(key in self.rope_scaling for key in required_keys): - raise ValueError(f"rope_scaling must contain keys {required_keys}") - - if self.rope_scaling["type"] != "linear": - raise ValueError("rope_scaling 'type' currently only supports 'linear'") - - -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def _norm(self, x): - return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) - - def __call__(self, x): - output = self._norm(x.astype(mx.float32)).astype(x.dtype) - return self.weight * output - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - self.repeats 
= n_heads // n_kv_heads - - head_dim = args.hidden_size // n_heads - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - - rope_scale = ( - 1 / args.rope_scaling["factor"] - if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" - else 1 - ) - self.rope = nn.RoPE( - head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - scale=rope_scale, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - return a.reshape([B, self.n_heads, L, -1]) - - if self.repeats > 1: - keys, values = map(repeat, (keys, values)) - - if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args) - self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out, cache - - -class LlamaModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = 
RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - h = self.embed_tokens(inputs) - - mask = None - if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) - mask = mask.astype(h.dtype) - - if cache is None: - cache = [None] * len(self.layers) - - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) - - return self.norm(h), cache - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model = LlamaModel(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - out, cache = self.model(inputs, cache) - return self.lm_head(out), cache diff --git a/lora/models/lora.py b/lora/models/lora.py deleted file mode 100644 index 8f3c01eb..00000000 --- a/lora/models/lora.py +++ /dev/null @@ -1,86 +0,0 @@ -import math - -import mlx.core as mx -import mlx.nn as nn - - -class LoRALinear(nn.Module): - @staticmethod - def from_linear(linear: nn.Linear, rank: int = 8): - # TODO remove when input_dims and output_dims are attributes - # on linear and quantized linear - output_dims, input_dims = linear.weight.shape - if isinstance(linear, nn.QuantizedLinear): - input_dims *= 32 // linear.bits - lora_lin = LoRALinear(input_dims, output_dims, rank) - lora_lin.linear = linear - return lora_lin - - def to_linear(self, de_quantize: bool = False): - linear = self.linear - bias = "bias" in linear - weight = linear.weight - is_quantized = isinstance(linear, nn.QuantizedLinear) - - # Use the same type as the linear weight if not quantized - dtype = weight.dtype - - if is_quantized: - dtype = mx.float16 - weight = mx.dequantize( - weight, - linear.scales, - linear.biases, - linear.group_size, - linear.bits, - ) - output_dims, input_dims = weight.shape - fused_linear = nn.Linear(input_dims, output_dims, bias=bias) - - lora_b = (self.scale * self.lora_b.T).astype(dtype) - lora_a = self.lora_a.T.astype(dtype) - fused_linear.weight = weight + lora_b @ lora_a - if bias: - fused_linear.bias = linear.bias - - if is_quantized and not de_quantize: - fused_linear = nn.QuantizedLinear.from_linear( - fused_linear, - linear.group_size, - linear.bits, - ) - - return fused_linear - - def __init__( - self, - input_dims: int, - output_dims: int, - lora_rank: int = 8, - bias: bool = False, - scale: float = 20.0, - ): - super().__init__() - - # Regular linear layer weights - self.linear = nn.Linear(input_dims, output_dims, bias=bias) - - # Scale for low-rank update - self.scale = scale - - # Low rank lora weights - scale = 1 / math.sqrt(input_dims) - self.lora_a = mx.random.uniform( - low=-scale, - high=scale, - shape=(input_dims, lora_rank), - ) - self.lora_b = mx.zeros(shape=(lora_rank, output_dims)) - - def __call__(self, x): - dtype = self.linear.weight.dtype - if isinstance(self.linear, nn.QuantizedLinear): - dtype = self.linear.scales.dtype - y = self.linear(x.astype(dtype)) - z = (x @ self.lora_a) @ self.lora_b - return y + self.scale * z diff --git a/lora/models/mixtral.py b/lora/models/mixtral.py deleted file mode 100644 index e70e0d2f..00000000 --- a/lora/models/mixtral.py +++ /dev/null @@ -1,253 +0,0 @@ -from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn -import numpy as np - -from .base import BaseModelArgs - - -@dataclass -class ModelArgs(BaseModelArgs): - vocab_size: int = 32000 - 
max_position_embeddings: int = 4096 * 32 - hidden_size: int = 4096 - intermediate_size: int = 14336 - num_hidden_layers: int = 32 - num_attention_heads: int = 32 - num_experts_per_tok: int = 2 - num_key_value_heads: int = 8 - num_local_experts: int = 8 - rms_norm_eps: float = 1e-5 - vocab_size: int - rope_theta: float = 1e6 - rope_traditional: bool = False - model_type: str = None - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def _norm(self, x): - return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) - - def __call__(self, x): - output = self._norm(x.astype(mx.float32)).astype(x.dtype) - return self.weight * output - - -class MixtralAttention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_size = args.hidden_size - self.num_heads = args.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = args.num_key_value_heads - self.max_position_embeddings = args.max_position_embeddings - self.rope_theta = args.rope_theta - - self.repeats = self.num_heads // self.num_key_value_heads - - self.scale = self.head_dim**-0.5 - - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False - ) - - self.rope = nn.RoPE( - self.head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( - 0, 2, 1, 3 - ) - - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - return a.reshape([B, self.num_heads, L, -1]) - - if self.repeats > 1: - keys, values = map(repeat, (keys, values)) - - if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) - - -class MixtralBLockSparseTop2MLP(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.ffn_dim = args.intermediate_size - self.hidden_dim = args.hidden_size - - self.w1 = nn.Linear(self.hidden_dim, 
self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - - self.act_fn = nn.silu - - def __call__(self, x: mx.array) -> mx.array: - current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states - - -class MixtralSparseMoeBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_dim = args.hidden_size - self.ffn_dim = args.intermediate_size - self.num_experts = args.num_local_experts - self.num_experts_per_tok = args.num_experts_per_tok - - # gating - self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - - self.experts = [ - MixtralBLockSparseTop2MLP(args=args) for _ in range(self.num_experts) - ] - - def __call__(self, x: mx.array) -> mx.array: - ne = self.num_experts_per_tok - orig_shape = x.shape - x = x.reshape(-1, x.shape[-1]) - - gates = self.gate(x) - inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]) - - scores = mx.softmax( - mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), - axis=-1, - ).astype(gates.dtype) - - mx.eval(inds) - inds = np.array(inds) - y = mx.zeros((x.shape[0], ne, x.shape[-1])) - for e, expert in enumerate(self.experts): - idx1, idx2 = map(mx.array, np.where(inds == e)) - if idx1.size == 0: - continue - y[idx1, idx2] = expert(x[idx1]) - - y = (y * scores[:, :, None]).sum(axis=1) - - return y.reshape(orig_shape) - - -class MixtralDecoderLayer(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_size = args.hidden_size - - self.self_attn = MixtralAttention(args) - - self.block_sparse_moe = MixtralSparseMoeBlock(args) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.block_sparse_moe(self.post_attention_layernorm(h)) - out = h + r - return out, cache - - -class MixtralModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - h = self.embed_tokens(inputs) - - mask = None - T = h.shape[1] - if T > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(T) - mask = mask.astype(h.dtype) - - if cache is None: - cache = [None] * len(self.layers) - - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) - - return self.norm(h), cache - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model = MixtralModel(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - out, cache = self.model(inputs, cache) - return self.lm_head(out), cache diff --git a/lora/models/phi2.py b/lora/models/phi2.py deleted file mode 100644 index 51b5e390..00000000 --- a/lora/models/phi2.py +++ /dev/null @@ 
-1,138 +0,0 @@
-import math
-from dataclasses import dataclass
-
-import mlx.core as mx
-import mlx.nn as nn
-
-from .base import BaseModelArgs
-
-
-@dataclass
-class ModelArgs(BaseModelArgs):
-    n_positions: int = 2048
-    vocab_size: int = 51200
-    n_embd: int = 2560
-    n_head: int = 32
-    n_layer: int = 32
-    rotary_dim: int = 32
-
-
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
-class RoPEAttention(nn.Module):
-    def __init__(self, dims: int, n_head: int, rotary_dim: int):
-        super().__init__()
-
-        self.n_head = n_head
-
-        self.q_proj = nn.Linear(dims, dims)
-        self.k_proj = nn.Linear(dims, dims)
-        self.v_proj = nn.Linear(dims, dims)
-        self.dense = nn.Linear(dims, dims)
-
-        self.rope = nn.RoPE(rotary_dim, traditional=False)
-
-    def __call__(self, x, mask=None, cache=None):
-        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-        # Extract some shapes
-        n_head = self.n_head
-        B, L, D = queries.shape
-
-        # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, n_head, -1).transpose(0, 2, 1, 3)
-
-        # Add RoPE to the queries and keys and combine them with the cache
-        if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
-
-        queries = queries.astype(mx.float32)
-        keys = keys.astype(mx.float32)
-
-        # Finally perform the attention computation
-        scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores = scores + mask
-
-        scores = mx.softmax(scores, axis=-1).astype(values.dtype)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.dense(values_hat), (keys, values)
-
-
-class MLP(nn.Module):
-    def __init__(self, dim, hidden_dim):
-        super().__init__()
-        self.fc1 = nn.Linear(dim, hidden_dim)
-        self.fc2 = nn.Linear(hidden_dim, dim)
-        self.act = nn.GELU(approx="precise")
-
-    def __call__(self, x) -> mx.array:
-        return self.fc2(self.act(self.fc1(x)))
-
-
-class ParallelBlock(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        dims = config.n_embd
-        mlp_dims = dims * 4
-        self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim)
-        self.input_layernorm = LayerNorm(dims)
-        self.mlp = MLP(dims, mlp_dims)
-
-    def __call__(self, x, mask, cache):
-        h = self.input_layernorm(x)
-        attn_h, cache = self.self_attn(h, mask, cache)
-        ff_h = self.mlp(h)
-        return attn_h + ff_h + x, cache
-
-
-class Transformer(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
-        self.layers = [ParallelBlock(config) for i in range(config.n_layer)]
-        self.final_layernorm = LayerNorm(config.n_embd)
-
-    def __call__(self, x, mask, cache):
-        x = self.embed_tokens(x)
-        if cache is None:
-            cache = [None] * len(self.layers)
-
-        for e, layer in enumerate(self.layers):
-            x, cache[e] = layer(x, mask, cache[e])
-        return self.final_layernorm(x), cache
-
-
-class Model(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        self.model = Transformer(config)
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: mx.array = None,
-        cache: mx.array = None,
-    ) -> tuple[mx.array, mx.array]:
-        mask = None
-        if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-            mask = mask.astype(x.dtype)
-
-        y, cache = self.model(x, mask, cache)
-        return self.lm_head(y), cache
diff --git a/lora/requirements.txt b/lora/requirements.txt
index 9abc3e88..12199383 100644
--- a/lora/requirements.txt
+++ b/lora/requirements.txt
@@ -1,3 +1,3 @@
-mlx>=0.0.7
+mlx>=0.8.0
 transformers
 numpy
diff --git a/lora/utils.py b/lora/utils.py
index c76b097a..5c791561 100644
--- a/lora/utils.py
+++ b/lora/utils.py
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import glob
 import json
@@ -8,40 +8,10 @@ from typing import Generator
 
 import mlx.core as mx
 import mlx.nn as nn
-import models.llama as llama
-import models.mixtral as mixtral
-import models.phi2 as phi2
+import models
 import transformers
 from huggingface_hub import snapshot_download
 
-# Constants
-MODEL_MAPPING = {
-    "llama": llama,
-    "mistral": llama,  # mistral is compatible with llama
-    "phi": phi2,
-    "mixtral": mixtral,
-}
-
-
-def _get_classes(config: dict):
-    """
-    Retrieve the model and model args classes based on the configuration.
-
-    Args:
-        config (dict): The model configuration.
-
-    Returns:
-        A tuple containing the Model class and the ModelArgs class.
-    """
-    model_type = config["model_type"]
-    if model_type not in MODEL_MAPPING:
-        msg = f"Model type {model_type} not supported."
-        logging.error(msg)
-        raise ValueError(msg)
-
-    arch = MODEL_MAPPING[model_type]
-    return arch.Model, arch.ModelArgs
-
 
 def fetch_from_hub(hf_path: str):
     model_path = snapshot_download(
@@ -157,9 +127,8 @@ def load(path_or_hf_repo: str):
     for wf in weight_files:
         weights.update(mx.load(wf).items())
 
-    model_class, model_args_class = _get_classes(config=config)
-    model_args = model_args_class.from_dict(config)
-    model = model_class(model_args)
+    model_args = models.ModelArgs.from_dict(config)
+    model = models.Model(model_args)
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(
             model,
diff --git a/t5/requirements.txt b/t5/requirements.txt
index 4a37303a..a7f031cb 100644
--- a/t5/requirements.txt
+++ b/t5/requirements.txt
@@ -1,3 +1,3 @@
-mlx
+mlx>=0.8.0
 numpy
 transformers
diff --git a/t5/t5.py b/t5/t5.py
index 556c2503..89f2e486 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -134,21 +134,6 @@ class MultiHeadAttention(nn.Module):
         return self.out_proj(values_hat), (keys, values)
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        t = x.dtype
-        output = self._norm(x).astype(t)
-        return self.weight * output
-
-
 class DenseActivation(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
@@ -184,8 +169,8 @@ class TransformerEncoderLayer(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
         self.attention = MultiHeadAttention(config)
-        self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dense = DenseActivation(config)
 
     def __call__(self, x, mask):
@@ -204,7 +189,7 @@ class TransformerEncoder(nn.Module):
         self.layers = [
             TransformerEncoderLayer(config) for i in range(config.num_layers)
         ]
-        self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
 
     def __call__(self, x: mx.array):
@@ -219,9 +204,9 @@ class TransformerDecoderLayer(nn.Module):
         super().__init__()
         self.self_attention = MultiHeadAttention(config)
         self.cross_attention = MultiHeadAttention(config)
-        self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln3 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dense = DenseActivation(config)
 
     def __call__(
@@ -252,7 +237,7 @@ class TransformerDecoder(nn.Module):
         super().__init__()
         n_layers = getattr(config, "num_decoder_layers", config.num_layers)
         self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
-        self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
 
     def __call__(self, x, memory, mask, memory_mask, cache=None):
diff --git a/whisper/requirements.txt b/whisper/requirements.txt
index 23d43200..62f55737 100644
--- a/whisper/requirements.txt
+++ b/whisper/requirements.txt
@@ -1,4 +1,4 @@
-mlx>=0.1
+mlx>=0.8
 numba
 numpy
 torch
diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py
index 183eacc9..37495130 100644
--- a/whisper/whisper/whisper.py
+++ b/whisper/whisper/whisper.py
@@ -37,11 +37,6 @@ def sinusoids(length, channels, max_timescale=10000):
     return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
 
 
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class MultiHeadAttention(nn.Module):
     def __init__(self, n_state: int, n_head: int):
         super().__init__()
@@ -98,17 +93,17 @@ class ResidualAttentionBlock(nn.Module):
         super().__init__()
 
         self.attn = MultiHeadAttention(n_state, n_head)
-        self.attn_ln = LayerNorm(n_state)
+        self.attn_ln = nn.LayerNorm(n_state)
 
         self.cross_attn = (
             MultiHeadAttention(n_state, n_head) if cross_attention else None
         )
-        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+        self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
 
         n_mlp = n_state * 4
         self.mlp1 = nn.Linear(n_state, n_mlp)
         self.mlp2 = nn.Linear(n_mlp, n_state)
-        self.mlp_ln = LayerNorm(n_state)
+        self.mlp_ln = nn.LayerNorm(n_state)
 
     def __call__(self, x, xa=None, mask=None, kv_cache=None):
         kv, cross_kv = kv_cache if kv_cache else (None, None)
@@ -140,7 +135,7 @@ class AudioEncoder(nn.Module):
         self._positional_embedding = sinusoids(n_ctx, n_state).astype(dtype)
 
         self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
-        self.ln_post = LayerNorm(n_state)
+        self.ln_post = nn.LayerNorm(n_state)
 
     def __call__(self, x):
         x = nn.gelu(self.conv1(x)).astype(x.dtype)
@@ -174,7 +169,7 @@ class TextDecoder(nn.Module):
             ResidualAttentionBlock(n_state, n_head, cross_attention=True)
             for _ in range(n_layer)
         ]
-        self.ln = LayerNorm(n_state)
+        self.ln = nn.LayerNorm(n_state)
         self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(
             dtype
         )