diff --git a/llms/mlx_lm/__init__.py b/llms/mlx_lm/__init__.py
index e67fe473..82960423 100644
--- a/llms/mlx_lm/__init__.py
+++ b/llms/mlx_lm/__init__.py
@@ -1,4 +1,4 @@
 from .convert import convert
 from .utils import generate, load
 
-__version__ = "0.0.13"
+__version__ = "0.0.14"
diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py
index cd667d37..2bc782b7 100644
--- a/llms/mlx_lm/models/gemma.py
+++ b/llms/mlx_lm/models/gemma.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from functools import partial
 from typing import Dict, Optional, Tuple, Union
 
 import mlx.core as mx
@@ -22,18 +23,21 @@ class ModelArgs(BaseModelArgs):
     rope_traditional: bool = False
 
 
+@partial(mx.compile, shapeless=True)
+def rms_norm(x, weight, eps):
+    x = x.astype(mx.float32)
+    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
+    return (1.0 + weight) * x.astype(weight.dtype)
+
+
 class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
         self.weight = mx.ones((dims,))
         self.eps = eps
 
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
     def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return (1 + self.weight) * output
+        return rms_norm(x, self.weight, self.eps)
 
 
 class Attention(nn.Module):
diff --git a/llms/mlx_lm/models/layers.py b/llms/mlx_lm/models/layers.py
new file mode 100644
index 00000000..77d9831f
--- /dev/null
+++ b/llms/mlx_lm/models/layers.py
@@ -0,0 +1,51 @@
+from functools import partial
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+@partial(mx.compile, shapeless=True)
+def rms_norm(x, weight, eps):
+    x = x.astype(mx.float32)
+    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
+    return weight * x.astype(weight.dtype)
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5):
+        super().__init__()
+        self.weight = mx.ones((dims,))
+        self.eps = eps
+
+    def __call__(self, x):
+        return rms_norm(x, self.weight, self.eps)
+
+
+@partial(mx.compile, shapeless=True)
+def ln_norm(x, eps, weight=None, bias=None):
+    t = x.dtype
+    x = x.astype(mx.float32)
+    means = mx.mean(x, axis=-1, keepdims=True)
+    var = mx.var(x, axis=-1, keepdims=True)
+    x = (x - means) * mx.rsqrt(var + eps)
+    x = x.astype(t)
+    return weight * x + bias if weight is not None else x
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
+        super().__init__()
+        if affine:
+            self.bias = mx.zeros((dims,))
+            self.weight = mx.ones((dims,))
+        self.eps = eps
+        self.dims = dims
+
+    def _extra_repr(self):
+        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
+
+    def __call__(self, x: mx.array) -> mx.array:
+        if "weight" in self:
+            return ln_norm(x, self.eps, self.weight, self.bias)
+        else:
+            return ln_norm(x, self.eps)
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index a38db95b..66105896 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -34,20 +35,6 @@ class ModelArgs(BaseModelArgs):
                 raise ValueError("rope_scaling 'type' currently only supports 'linear'")
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index f584d509..c2ddcb7c 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -6,6 +6,7 @@ import mlx.nn as nn
 import numpy as np
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -30,20 +31,6 @@ class ModelArgs(BaseModelArgs):
             self.num_key_value_heads = self.num_attention_heads
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class MixtralAttention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index f9fe1475..f97ce6f9 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -6,6 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import LayerNorm
 
 try:
     import hf_olmo
@@ -37,11 +38,6 @@ class ModelArgs(BaseModelArgs):
         )
 
 
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class TransformerBlock(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py
index 84c5d4f9..85d16759 100644
--- a/llms/mlx_lm/models/phi.py
+++ b/llms/mlx_lm/models/phi.py
@@ -6,6 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import LayerNorm
 
 
 @dataclass
@@ -27,11 +28,6 @@ class ModelArgs(BaseModelArgs):
             self.num_key_value_heads = self.num_attention_heads
 
 
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class PhiAttention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py
index 9cd23997..8537645a 100644
--- a/llms/mlx_lm/models/phixtral.py
+++ b/llms/mlx_lm/models/phixtral.py
@@ -1,17 +1,13 @@
-import glob
 import inspect
-import json
 import math
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import Optional, Tuple
+from dataclasses import dataclass
+from typing import Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from huggingface_hub import snapshot_download
-from mlx.utils import tree_unflatten
-from transformers import AutoTokenizer
+
+from .layers import LayerNorm
 
 
 @dataclass
@@ -37,11 +33,6 @@ class ModelArgs:
         )
 
 
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class RoPEAttention(nn.Module):
     def __init__(self, dims: int, num_heads: int, rotary_dim: int):
         super().__init__()
diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py
index b6ca2491..ba026335 100644
--- a/llms/mlx_lm/models/plamo.py
+++ b/llms/mlx_lm/models/plamo.py
@@ -6,6 +6,7 @@ import mlx.nn as nn
 import numpy as np
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -22,20 +23,6 @@ class ModelArgs(BaseModelArgs):
     rope_traditional: bool = False
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.variance_epsilon = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.variance_epsilon)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py
index aeda9c32..16609414 100644
--- a/llms/mlx_lm/models/qwen.py
+++ b/llms/mlx_lm/models/qwen.py
@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -26,20 +27,6 @@ class ModelArgs(BaseModelArgs):
            self.num_key_value_heads = self.num_attention_heads
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index 59aa6918..f0c19171 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -34,20 +35,6 @@ class ModelArgs(BaseModelArgs):
                 raise ValueError("rope_scaling 'type' currently only supports 'linear'")
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-6):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
diff --git a/llms/mlx_lm/models/stablelm_epoch.py b/llms/mlx_lm/models/stablelm_epoch.py
index 2f88bd03..6b13012d 100644
--- a/llms/mlx_lm/models/stablelm_epoch.py
+++ b/llms/mlx_lm/models/stablelm_epoch.py
@@ -6,6 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import LayerNorm
 
 
 @dataclass
@@ -24,11 +25,6 @@ class ModelArgs(BaseModelArgs):
     use_qkv_bias: bool
 
 
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt
index e9d4d68e..049049e7 100644
--- a/llms/mlx_lm/requirements.txt
+++ b/llms/mlx_lm/requirements.txt
@@ -1,4 +1,4 @@
-mlx>=0.1
+mlx>=0.4
 numpy
 transformers>=4.38.0
 protobuf
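
Taken together, the patch moves the per-model float32 norm implementations into a shared `llms/mlx_lm/models/layers.py` and wraps the math in `mx.compile(shapeless=True)`, which is presumably why it also bumps the `mlx` requirement to `>=0.4`. Below is a minimal usage sketch of the new shared layers, not part of the patch: it assumes the package is installed and importable as `mlx_lm`, and the shapes, dtype, and tolerance are illustrative only.

# Usage sketch (not from the patch): exercise the shared norms added in
# llms/mlx_lm/models/layers.py and compare against the mlx.nn reference layers.
import mlx.core as mx
import mlx.nn as nn

from mlx_lm.models.layers import LayerNorm, RMSNorm

# Hypothetical half-precision activations.
x = mx.random.normal((2, 8, 64)).astype(mx.float16)

rms = RMSNorm(64, eps=1e-5)
ln = LayerNorm(64, eps=1e-5, affine=True)

# Both layers do the normalization in float32 inside the compiled functions,
# as the per-model implementations they replace did.
y_rms = rms(x)
y_ln = ln(x)

# Sanity check against the built-in layers (illustrative tolerance).
print(mx.allclose(y_rms, nn.RMSNorm(64, eps=1e-5)(x), atol=1e-2))
print(mx.allclose(y_ln, nn.LayerNorm(64, eps=1e-5)(x), atol=1e-2))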